3 | | {{{#!diff |
4 | | diff --git a/mediagoblin/plugins/oauth/__init__.py b/mediagoblin/plugins/oauth/__init__.py |
5 | | index 4714d95..5762379 100644 |
6 | | --- a/mediagoblin/plugins/oauth/__init__.py |
7 | | +++ b/mediagoblin/plugins/oauth/__init__.py |
8 | | @@ -34,7 +34,7 @@ def setup_plugin(): |
9 | | _log.debug('OAuth config: {0}'.format(config)) |
10 | | |
11 | | routes = [ |
12 | | - ('mediagoblin.plugins.oauth.authorize', |
13 | | + ('mediagoblin.plugins.oauth.authorize', |
14 | | '/oauth/authorize', |
15 | | 'mediagoblin.plugins.oauth.views:authorize'), |
16 | | ('mediagoblin.plugins.oauth.authorize_client', |
17 | | diff --git a/mediagoblin/plugins/oauth/migrations.py b/mediagoblin/plugins/oauth/migrations.py |
18 | | index 6aa0d7c..f70a2e8 100644 |
19 | | --- a/mediagoblin/plugins/oauth/migrations.py |
20 | | +++ b/mediagoblin/plugins/oauth/migrations.py |
21 | | @@ -102,6 +102,22 @@ class OAuthCode_v0(declarative_base()): |
22 | | client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False) |
23 | | |
24 | | |
25 | | +class OAuthRefreshToken_v0(declarative_base()): |
26 | | + __tablename__ = 'oauth__refresh_tokens' |
27 | | + |
28 | | + id = Column(Integer, primary_key=True) |
29 | | + created = Column(DateTime, nullable=False, |
30 | | + default=datetime.now) |
31 | | + |
32 | | + token = Column(Unicode, index=True) |
33 | | + |
34 | | + user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
35 | | + index=True) |
36 | | + |
37 | | + # XXX: Is it OK to use OAuthClient_v0.id in this way? |
38 | | + client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False) |
39 | | + |
40 | | + |
41 | | @RegisterMigration(1, MIGRATIONS) |
42 | | def remove_and_replace_token_and_code(db): |
43 | | metadata = MetaData(bind=db.bind) |
44 | | @@ -122,3 +138,22 @@ def remove_and_replace_token_and_code(db): |
45 | | OAuthCode_v0.__table__.create(db.bind) |
46 | | |
47 | | db.commit() |
48 | | + |
49 | | + |
50 | | +@RegisterMigration(2, MIGRATIONS) |
51 | | +def remove_refresh_token_field(db): |
52 | | + metadata = MetaData(bind=db.bind) |
53 | | + |
54 | | + token_table = Table('oauth__tokens', metadata, autoload=True, |
55 | | + autoload_with=db.bind) |
56 | | + |
57 | | + refresh_token = token_table.columns['refresh_token'] |
58 | | + |
59 | | + refresh_token.drop() |
60 | | + db.commit() |
61 | | + |
62 | | +@RegisterMigration(3, MIGRATIONS) |
63 | | +def create_refresh_token_table(db): |
64 | | + OAuthRefreshToken_v0.__table__.create(db.bind) |
65 | | + |
66 | | + db.commit() |
67 | | diff --git a/mediagoblin/plugins/oauth/models.py b/mediagoblin/plugins/oauth/models.py |
68 | | index 695dad3..28735dd 100644 |
69 | | --- a/mediagoblin/plugins/oauth/models.py |
70 | | +++ b/mediagoblin/plugins/oauth/models.py |
71 | | @@ -19,12 +19,14 @@ import bcrypt |
72 | | |
73 | | from datetime import datetime, timedelta |
74 | | |
75 | | -from mediagoblin.db.base import Base |
76 | | -from mediagoblin.db.models import User |
77 | | |
78 | | from sqlalchemy import ( |
79 | | Column, Unicode, Integer, DateTime, ForeignKey, Enum) |
80 | | from sqlalchemy.orm import relationship |
81 | | +from mediagoblin.db.base import Base |
82 | | +from mediagoblin.db.models import User |
83 | | +from mediagoblin.plugins.oauth.tools import generate_identifier, \ |
84 | | + generate_secret, generate_token, generate_code, generate_refresh_token |
85 | | |
86 | | # Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto |
87 | | # the models. |
88 | | @@ -41,8 +43,9 @@ class OAuthClient(Base): |
89 | | name = Column(Unicode) |
90 | | description = Column(Unicode) |
91 | | |
92 | | - identifier = Column(Unicode, unique=True, index=True) |
93 | | - secret = Column(Unicode, index=True) |
94 | | + identifier = Column(Unicode, unique=True, index=True, |
95 | | + default=generate_identifier) |
96 | | + secret = Column(Unicode, index=True, default=generate_secret) |
97 | | |
98 | | owner_id = Column(Integer, ForeignKey(User.id)) |
99 | | owner = relationship(User, backref='registered_clients') |
100 | | @@ -54,14 +57,8 @@ class OAuthClient(Base): |
101 | | u'public', |
102 | | name=u'oauth__client_type')) |
103 | | |
104 | | - def generate_identifier(self): |
105 | | - self.identifier = unicode(uuid.uuid4()) |
106 | | - |
107 | | - def generate_secret(self): |
108 | | - self.secret = unicode( |
109 | | - bcrypt.hashpw( |
110 | | - unicode(uuid.uuid4()), |
111 | | - bcrypt.gensalt())) |
112 | | + def update_secret(self): |
113 | | + self.secret = generate_secret() |
114 | | |
115 | | def __repr__(self): |
116 | | return '<{0} {1}:{2} ({3})>'.format( |
117 | | @@ -76,10 +73,10 @@ class OAuthUserClient(Base): |
118 | | id = Column(Integer, primary_key=True) |
119 | | |
120 | | user_id = Column(Integer, ForeignKey(User.id)) |
121 | | - user = relationship(User, backref='oauth_clients') |
122 | | + user = relationship(User, backref='oauth_client_relations') |
123 | | |
124 | | client_id = Column(Integer, ForeignKey(OAuthClient.id)) |
125 | | - client = relationship(OAuthClient, backref='users') |
126 | | + client = relationship(OAuthClient, backref='oauth_user_relations') |
127 | | |
128 | | state = Column(Enum( |
129 | | u'approved', |
130 | | @@ -103,8 +100,7 @@ class OAuthToken(Base): |
131 | | default=datetime.now) |
132 | | expires = Column(DateTime, nullable=False, |
133 | | default=lambda: datetime.now() + timedelta(days=30)) |
134 | | - token = Column(Unicode, index=True) |
135 | | - refresh_token = Column(Unicode, index=True) |
136 | | + token = Column(Unicode, index=True, default=generate_token) |
137 | | |
138 | | user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
139 | | index=True) |
140 | | @@ -121,6 +117,31 @@ class OAuthToken(Base): |
141 | | self.user, |
142 | | self.client) |
143 | | |
144 | | +class OAuthRefreshToken(Base): |
145 | | + __tablename__ = 'oauth__refresh_tokens' |
146 | | + |
147 | | + id = Column(Integer, primary_key=True) |
148 | | + created = Column(DateTime, nullable=False, |
149 | | + default=datetime.now) |
150 | | + |
151 | | + token = Column(Unicode, index=True, |
152 | | + default=generate_refresh_token) |
153 | | + |
154 | | + user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
155 | | + index=True) |
156 | | + |
157 | | + user = relationship(User) |
158 | | + |
159 | | + client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False) |
160 | | + client = relationship(OAuthClient) |
161 | | + |
162 | | + def __repr__(self): |
163 | | + return '<{0} #{1} [{2}, {3}]>'.format( |
164 | | + self.__class__.__name__, |
165 | | + self.id, |
166 | | + self.user, |
167 | | + self.client) |
168 | | + |
169 | | |
170 | | class OAuthCode(Base): |
171 | | __tablename__ = 'oauth__codes' |
172 | | @@ -130,7 +151,7 @@ class OAuthCode(Base): |
173 | | default=datetime.now) |
174 | | expires = Column(DateTime, nullable=False, |
175 | | default=lambda: datetime.now() + timedelta(minutes=5)) |
176 | | - code = Column(Unicode, index=True) |
177 | | + code = Column(Unicode, index=True, default=generate_code) |
178 | | |
179 | | user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
180 | | index=True) |
181 | | @@ -150,6 +171,7 @@ class OAuthCode(Base): |
182 | | |
183 | | MODELS = [ |
184 | | OAuthToken, |
185 | | + OAuthRefreshToken, |
186 | | OAuthCode, |
187 | | OAuthClient, |
188 | | OAuthUserClient] |
189 | | diff --git a/mediagoblin/plugins/oauth/tools.py b/mediagoblin/plugins/oauth/tools.py |
190 | | index d21c8a5..25d0977 100644 |
191 | | --- a/mediagoblin/plugins/oauth/tools.py |
192 | | +++ b/mediagoblin/plugins/oauth/tools.py |
193 | | @@ -1,3 +1,4 @@ |
194 | | +# -*- coding: utf-8 -*- |
195 | | # GNU MediaGoblin -- federated, autonomous media hosting |
196 | | # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS. |
197 | | # |
198 | | @@ -14,13 +15,25 @@ |
199 | | # You should have received a copy of the GNU Affero General Public License |
200 | | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
201 | | |
202 | | +import uuid |
203 | | +import bcrypt |
204 | | + |
205 | | +from datetime import datetime |
206 | | + |
207 | | from functools import wraps |
208 | | |
209 | | -from mediagoblin.plugins.oauth.models import OAuthClient |
210 | | from mediagoblin.plugins.api.tools import json_response |
211 | | |
212 | | |
213 | | def require_client_auth(controller): |
214 | | + ''' |
215 | | + View decorator |
216 | | + |
217 | | + - Requires the presence of ``?client_id`` |
218 | | + ''' |
219 | | + # Avoid circular import |
220 | | + from mediagoblin.plugins.oauth.models import OAuthClient |
221 | | + |
222 | | @wraps(controller) |
223 | | def wrapper(request, *args, **kw): |
224 | | if not request.GET.get('client_id'): |
225 | | @@ -41,3 +54,63 @@ def require_client_auth(controller): |
226 | | return controller(request, client) |
227 | | |
228 | | return wrapper |
229 | | + |
230 | | + |
231 | | +def create_token(client, user): |
232 | | + ''' |
233 | | + Create an OAuthToken and an OAuthRefreshToken entry in the database |
234 | | + |
235 | | + Returns the data structure expected by the OAuth clients. |
236 | | + ''' |
237 | | + from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken |
238 | | + |
239 | | + token = OAuthToken() |
240 | | + token.user = user |
241 | | + token.client = client |
242 | | + token.save() |
243 | | + |
244 | | + refresh_token = OAuthRefreshToken() |
245 | | + refresh_token.user = user |
246 | | + refresh_token.client = client |
247 | | + refresh_token.save() |
248 | | + |
249 | | + # expire time of token in full seconds |
250 | | + # timedelta.total_seconds is python >= 2.7 or we would use that |
251 | | + td = token.expires - datetime.now() |
252 | | + exp_in = 86400*td.days + td.seconds # just ignore µsec |
253 | | + |
254 | | + return {'access_token': token.token, 'token_type': 'bearer', |
255 | | + 'refresh_token': refresh_token.token, 'expires_in': exp_in} |
256 | | + |
257 | | + |
258 | | +def generate_identifier(): |
259 | | + ''' Generates a ``uuid.uuid4()`` ''' |
260 | | + return unicode(uuid.uuid4()) |
261 | | + |
262 | | + |
263 | | +def generate_token(): |
264 | | + ''' Uses generate_identifier ''' |
265 | | + return generate_identifier() |
266 | | + |
267 | | + |
268 | | +def generate_refresh_token(): |
269 | | + ''' Uses generate_identifier ''' |
270 | | + return generate_identifier() |
271 | | + |
272 | | + |
273 | | +def generate_code(): |
274 | | + ''' Uses generate_identifier ''' |
275 | | + return generate_identifier() |
276 | | + |
277 | | + |
278 | | +def generate_secret(): |
279 | | + ''' |
280 | | + Generate a long string of pseudo-random characters |
281 | | + ''' |
282 | | + # XXX: We might not want it to use bcrypt, since bcrypt takes its time to |
283 | | + # generate the result. |
284 | | + return unicode( |
285 | | + bcrypt.hashpw( |
286 | | + unicode(uuid.uuid4()), |
287 | | + bcrypt.gensalt())) |
288 | | + |
289 | | diff --git a/mediagoblin/plugins/oauth/views.py b/mediagoblin/plugins/oauth/views.py |
290 | | index c7b2a33..ad8ea8f 100644 |
291 | | --- a/mediagoblin/plugins/oauth/views.py |
292 | | +++ b/mediagoblin/plugins/oauth/views.py |
293 | | @@ -16,21 +16,21 @@ |
294 | | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
295 | | |
296 | | import logging |
297 | | -import json |
298 | | |
299 | | from urllib import urlencode |
300 | | -from uuid import uuid4 |
301 | | -from datetime import datetime |
302 | | + |
303 | | +from werkzeug.exceptions import BadRequest |
304 | | |
305 | | from mediagoblin.tools.response import render_to_response, redirect |
306 | | from mediagoblin.decorators import require_active_login |
307 | | -from mediagoblin.messages import add_message, SUCCESS, ERROR |
308 | | +from mediagoblin.messages import add_message, SUCCESS |
309 | | from mediagoblin.tools.translate import pass_to_ugettext as _ |
310 | | -from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \ |
311 | | - OAuthClient, OAuthUserClient |
312 | | +from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \ |
313 | | + OAuthUserClient, OAuthRefreshToken |
314 | | from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \ |
315 | | AuthorizationForm |
316 | | -from mediagoblin.plugins.oauth.tools import require_client_auth |
317 | | +from mediagoblin.plugins.oauth.tools import require_client_auth, \ |
318 | | + create_token |
319 | | from mediagoblin.plugins.api.tools import json_response |
320 | | |
321 | | _log = logging.getLogger(__name__) |
322 | | @@ -51,9 +51,6 @@ def register_client(request): |
323 | | client.owner_id = request.user.id |
324 | | client.redirect_uri = unicode(request.form['redirect_uri']) |
325 | | |
326 | | - client.generate_identifier() |
327 | | - client.generate_secret() |
328 | | - |
329 | | client.save() |
330 | | |
331 | | add_message(request, SUCCESS, _('The client {0} has been registered!')\ |
332 | | @@ -92,8 +89,8 @@ def authorize_client(request): |
333 | | form.client_id.data).first() |
334 | | |
335 | | if not client: |
336 | | - _log.error('''No such client id as received from client authorization |
337 | | - form.''') |
338 | | + _log.error('No such client id as received from client authorization \ |
339 | | +form.') |
340 | | return BadRequest() |
341 | | |
342 | | if form.validate(): |
343 | | @@ -105,7 +102,7 @@ def authorize_client(request): |
344 | | elif form.deny.data: |
345 | | relation.state = u'rejected' |
346 | | else: |
347 | | - return BadRequest |
348 | | + return BadRequest() |
349 | | |
350 | | relation.save() |
351 | | |
352 | | @@ -136,7 +133,7 @@ def authorize(request, client): |
353 | | return json_response({ |
354 | | 'status': 400, |
355 | | 'errors': |
356 | | - [u'Public clients MUST have a redirect_uri pre-set']}, |
357 | | + [u'Public clients should have a redirect_uri pre-set.']}, |
358 | | _disable_cors=True) |
359 | | |
360 | | redirect_uri = client.redirect_uri |
361 | | @@ -146,11 +143,10 @@ def authorize(request, client): |
362 | | if not redirect_uri: |
363 | | return json_response({ |
364 | | 'status': 400, |
365 | | - 'errors': [u'Can not find a redirect_uri for client: {0}'\ |
366 | | - .format(client.name)]}, _disable_cors=True) |
367 | | + 'errors': [u'No redirect_uri supplied!']}, |
368 | | + _disable_cors=True) |
369 | | |
370 | | code = OAuthCode() |
371 | | - code.code = unicode(uuid4()) |
372 | | code.user = request.user |
373 | | code.client = client |
374 | | code.save() |
375 | | @@ -180,59 +176,79 @@ def authorize(request, client): |
376 | | |
377 | | |
378 | | def access_token(request): |
379 | | + ''' |
380 | | + Access token endpoint provides access tokens to any clients that have the |
381 | | + right grants/credentials |
382 | | + ''' |
383 | | + |
384 | | + client = None |
385 | | + user = None |
386 | | + |
387 | | if request.GET.get('code'): |
388 | | + # Validate the code arg, then get the client object from the db. |
389 | | code = OAuthCode.query.filter(OAuthCode.code == |
390 | | request.GET.get('code')).first() |
391 | | |
392 | | - if code: |
393 | | - if code.client.type == u'confidential': |
394 | | - client_identifier = request.GET.get('client_id') |
395 | | - |
396 | | - if not client_identifier: |
397 | | - return json_response({ |
398 | | - 'error': 'invalid_request', |
399 | | - 'error_description': |
400 | | - 'Missing client_id in request'}) |
401 | | - |
402 | | - client_secret = request.GET.get('client_secret') |
403 | | - |
404 | | - if not client_secret: |
405 | | - return json_response({ |
406 | | - 'error': 'invalid_request', |
407 | | - 'error_description': |
408 | | - 'Missing client_secret in request'}) |
409 | | - |
410 | | - if not client_secret == code.client.secret or \ |
411 | | - not client_identifier == code.client.identifier: |
412 | | - return json_response({ |
413 | | - 'error': 'invalid_client', |
414 | | - 'error_description': |
415 | | - 'The client_id or client_secret does not match the' |
416 | | - ' code'}) |
417 | | - |
418 | | - token = OAuthToken() |
419 | | - token.token = unicode(uuid4()) |
420 | | - token.user = code.user |
421 | | - token.client = code.client |
422 | | - token.save() |
423 | | - |
424 | | - # expire time of token in full seconds |
425 | | - # timedelta.total_seconds is python >= 2.7 or we would use that |
426 | | - td = token.expires - datetime.now() |
427 | | - exp_in = 86400*td.days + td.seconds # just ignore µsec |
428 | | - |
429 | | - access_token_data = { |
430 | | - 'access_token': token.token, |
431 | | - 'token_type': 'bearer', |
432 | | - 'expires_in': exp_in} |
433 | | - return json_response(access_token_data, _disable_cors=True) |
434 | | - else: |
435 | | + if not code: |
436 | | return json_response({ |
437 | | 'error': 'invalid_request', |
438 | | 'error_description': |
439 | | - 'Invalid code'}) |
440 | | - else: |
441 | | - return json_response({ |
442 | | - 'error': 'invalid_request', |
443 | | - 'error_descriptin': |
444 | | - 'Missing `code` parameter in request'}) |
445 | | + 'Invalid code.'}) |
446 | | + |
447 | | + client = code.client |
448 | | + user = code.user |
449 | | + |
450 | | + elif request.args.get('refresh_token'): |
451 | | + # Validate a refresh token, then get the client object from the db. |
452 | | + refresh_token = OAuthRefreshToken.query.filter( |
453 | | + OAuthRefreshToken.token == |
454 | | + request.args.get('refresh_token')).first() |
455 | | + |
456 | | + if not refresh_token: |
457 | | + return json_response({ |
458 | | + 'error': 'invalid_request', |
459 | | + 'error_description': |
460 | | + 'Invalid refresh token.'}) |
461 | | + |
462 | | + client = refresh_token.client |
463 | | + user = refresh_token.user |
464 | | + |
465 | | + if client: |
466 | | + client_identifier = request.GET.get('client_id') |
467 | | + |
468 | | + if not client_identifier: |
469 | | + return json_response({ |
470 | | + 'error': 'invalid_request', |
471 | | + 'error_description': |
472 | | + 'Missing client_id in request.'}) |
473 | | + |
474 | | + if not client_identifier == client.identifier: |
475 | | + return json_response({ |
476 | | + 'error': 'invalid_client', |
477 | | + 'error_description': |
478 | | + 'Mismatching client credentials.'}) |
479 | | + |
480 | | + if client.type == u'confidential': |
481 | | + client_secret = request.GET.get('client_secret') |
482 | | + |
483 | | + if not client_secret: |
484 | | + return json_response({ |
485 | | + 'error': 'invalid_request', |
486 | | + 'error_description': |
487 | | + 'Missing client_secret in request.'}) |
488 | | + |
489 | | + if not client_secret == client.secret: |
490 | | + return json_response({ |
491 | | + 'error': 'invalid_client', |
492 | | + 'error_description': |
493 | | + 'Mismatching client credentials.'}) |
494 | | + |
495 | | + |
496 | | + access_token_data = create_token(client, user) |
497 | | + |
498 | | + return json_response(access_token_data, _disable_cors=True) |
499 | | + |
500 | | + return json_response({ |
501 | | + 'error': 'invalid_request', |
502 | | + 'error_description': |
503 | | + 'Missing `code` or `refresh_token` parameter in request.'}) |
504 | | diff --git a/mediagoblin/tests/test_oauth.py b/mediagoblin/tests/test_oauth.py |
505 | | index 94ba5da..f036569 100644 |
506 | | --- a/mediagoblin/tests/test_oauth.py |
507 | | +++ b/mediagoblin/tests/test_oauth.py |
508 | | @@ -16,12 +16,13 @@ |
509 | | |
510 | | import json |
511 | | import logging |
512 | | +import urllib |
513 | | |
514 | | from urlparse import parse_qs, urlparse |
515 | | |
516 | | from mediagoblin import mg_globals |
517 | | from mediagoblin.tools import template, pluginapi |
518 | | -from mediagoblin.tests.tools import get_app, fixture_add_user |
519 | | +from mediagoblin.tests.tools import get_app, fixture_add_user, expect_failure |
520 | | |
521 | | |
522 | | _log = logging.getLogger(__name__) |
523 | | @@ -70,7 +71,7 @@ class TestOAuth(object): |
524 | | assert response.status_int == 200 |
525 | | |
526 | | # Should display an error |
527 | | - assert ctx['form'].redirect_uri.errors |
528 | | + assert len(ctx['form'].redirect_uri.errors) |
529 | | |
530 | | # Should not pass through |
531 | | assert not client |
532 | | @@ -78,12 +79,16 @@ class TestOAuth(object): |
533 | | def test_2_successful_public_client_registration(self): |
534 | | ''' Successfully register a public client ''' |
535 | | self.login() |
536 | | + uri = 'http://foo.example' |
537 | | self.register_client(u'OMGOMG', 'public', 'OMG!', |
538 | | - 'http://foo.example') |
539 | | + uri) |
540 | | |
541 | | client = self.db.OAuthClient.query.filter( |
542 | | self.db.OAuthClient.name == u'OMGOMG').first() |
543 | | |
544 | | + # redirect_uri should be set |
545 | | + assert client.redirect_uri == uri |
546 | | + |
547 | | # Client should have been registered |
548 | | assert client |
549 | | |
550 | | @@ -111,7 +116,7 @@ class TestOAuth(object): |
551 | | redirect_uri = 'https://foo.example' |
552 | | response = self.app.get('/oauth/authorize', { |
553 | | 'client_id': client.identifier, |
554 | | - 'scope': 'admin', |
555 | | + 'scope': 'all', |
556 | | 'redirect_uri': redirect_uri}) |
557 | | |
558 | | # User-agent should NOT be redirected |
559 | | @@ -137,6 +142,7 @@ class TestOAuth(object): |
560 | | return authorization_response, client_identifier |
561 | | |
562 | | def get_code_from_redirect_uri(self, uri): |
563 | | + ''' Return the first ``code`` query-parameter value in a URI (KeyError if absent) ''' |
564 | | return parse_qs(urlparse(uri).query)['code'][0] |
565 | | |
566 | | def test_token_endpoint_successful_confidential_request(self): |
567 | | @@ -162,6 +168,11 @@ code={1}&client_secret={2}'.format(client_id, code, client.secret)) |
568 | | assert type(token_data['expires_in']) == int |
569 | | assert token_data['expires_in'] > 0 |
570 | | |
571 | | + # There should be a refresh token provided in the token data |
572 | | + assert len(token_data['refresh_token']) |
573 | | + |
574 | | + return client_id, token_data |
575 | | + |
576 | | def test_token_endpont_missing_id_confidential_request(self): |
577 | | ''' Unsuccessful request against token endpoint, missing client_id ''' |
578 | | code_redirect, client_id = self.test_4_authorize_confidential_client() |
579 | | @@ -181,4 +192,30 @@ code={0}&client_secret={1}'.format(code, client.secret)) |
580 | | assert 'error' in token_data |
581 | | assert not 'access_token' in token_data |
582 | | assert token_data['error'] == 'invalid_request' |
583 | | - assert token_data['error_description'] == 'Missing client_id in request' |
584 | | + assert len(token_data['error_description']) |
585 | | + |
586 | | + def test_refresh_token(self): |
587 | | + ''' Try to get a new access token using the refresh token ''' |
588 | | + # Get an access token and a refresh token |
589 | | + client_id, token_data =\ |
590 | | + self.test_token_endpoint_successful_confidential_request() |
591 | | + |
592 | | + client = self.db.OAuthClient.query.filter( |
593 | | + self.db.OAuthClient.identifier == client_id).first() |
594 | | + |
595 | | + token_res = self.app.get('/oauth/access_token', |
596 | | + {'refresh_token': token_data['refresh_token'], |
597 | | + 'client_id': client_id, |
598 | | + 'client_secret': client.secret |
599 | | + }) |
600 | | + |
601 | | + assert token_res.status_int == 200 |
602 | | + |
603 | | + new_token_data = json.loads(token_res.body) |
604 | | + |
605 | | + assert not 'error' in new_token_data |
606 | | + assert 'access_token' in new_token_data |
607 | | + assert 'token_type' in new_token_data |
608 | | + assert 'expires_in' in new_token_data |
609 | | + assert type(new_token_data['expires_in']) == int |
610 | | + assert new_token_data['expires_in'] > 0 |
611 | | diff --git a/mediagoblin/tests/tools.py b/mediagoblin/tests/tools.py |
612 | | index cc4a7ad..71954e0 100644 |
613 | | --- a/mediagoblin/tests/tools.py |
614 | | +++ b/mediagoblin/tests/tools.py |
615 | | @@ -19,6 +19,8 @@ import os |
616 | | import pkg_resources |
617 | | import shutil |
618 | | |
619 | | +import nose |
620 | | + |
621 | | from functools import wraps |
622 | | |
623 | | from paste.deploy import loadapp |
624 | | @@ -235,7 +237,7 @@ def fixture_media_entry(title=u"Some title", slug=None, |
625 | | entry.slug = slug |
626 | | entry.uploader = uploader or fixture_add_user().id |
627 | | entry.media_type = u'image' |
628 | | - |
629 | | + |
630 | | if gen_slug: |
631 | | entry.generate_slug() |
632 | | if save: |
633 | | @@ -263,3 +265,16 @@ def fixture_add_collection(name=u"My first Collection", user=None): |
634 | | Session.expunge(coll) |
635 | | |
636 | | return coll |
637 | | + |
638 | | + |
639 | | +def expect_failure(test): |
640 | | + @wraps(test) |
641 | | + def wrapper(*args, **kwargs): |
642 | | + try: |
643 | | + test(*args, **kwargs) |
644 | | + except Exception: |
645 | | + raise nose.SkipTest |
646 | | + else: |
647 | | + raise AssertionError( |
648 | | + 'Test is expected to fail, did you add a feature?') |
649 | | + return wrapper |
650 | | }}} |
| 3 | ''Edited to remove really long diff.'' |