Changes between Initial Version and Version 1 of Ticket #548, comment 1


Ignore:
Timestamp: Apr 5, 2013, 9:21:22 PM (11 years ago)
Author: joar

Legend:

Unmodified
Added
Removed
Modified
  • Ticket #548, comment 1

    initial v1  
    1 1 I've pushed a fix for this in the {{{oauth/refresh_tokens}}} branch at {{{git@github.com:joar/mediagoblin.git}}}.
    2 2
    3 {{{#!diff
    4 diff --git a/mediagoblin/plugins/oauth/__init__.py b/mediagoblin/plugins/oauth/__init__.py
    5 index 4714d95..5762379 100644
    6 --- a/mediagoblin/plugins/oauth/__init__.py
    7 +++ b/mediagoblin/plugins/oauth/__init__.py
    8 @@ -34,7 +34,7 @@ def setup_plugin():
    9      _log.debug('OAuth config: {0}'.format(config))
    10  
    11      routes = [
    12 -       ('mediagoblin.plugins.oauth.authorize',
    13 +        ('mediagoblin.plugins.oauth.authorize',
    14              '/oauth/authorize',
    15              'mediagoblin.plugins.oauth.views:authorize'),
    16          ('mediagoblin.plugins.oauth.authorize_client',
    17 diff --git a/mediagoblin/plugins/oauth/migrations.py b/mediagoblin/plugins/oauth/migrations.py
    18 index 6aa0d7c..f70a2e8 100644
    19 --- a/mediagoblin/plugins/oauth/migrations.py
    20 +++ b/mediagoblin/plugins/oauth/migrations.py
    21 @@ -102,6 +102,22 @@ class OAuthCode_v0(declarative_base()):
    22      client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
    23  
    24  
    25 +class OAuthRefreshToken_v0(declarative_base()):
    26 +    __tablename__ = 'oauth__refresh_tokens'
    27 +
    28 +    id = Column(Integer, primary_key=True)
    29 +    created = Column(DateTime, nullable=False,
    30 +                     default=datetime.now)
    31 +
    32 +    token = Column(Unicode, index=True)
    33 +
    34 +    user_id = Column(Integer, ForeignKey(User.id), nullable=False,
    35 +            index=True)
    36 +
    37 +    # XXX: Is it OK to use OAuthClient_v0.id in this way?
    38 +    client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
    39 +
    40 +
    41  @RegisterMigration(1, MIGRATIONS)
    42  def remove_and_replace_token_and_code(db):
    43      metadata = MetaData(bind=db.bind)
    44 @@ -122,3 +138,22 @@ def remove_and_replace_token_and_code(db):
    45      OAuthCode_v0.__table__.create(db.bind)
    46  
    47      db.commit()
    48 +
    49 +
    50 +@RegisterMigration(2, MIGRATIONS)
    51 +def remove_refresh_token_field(db):
    52 +    metadata = MetaData(bind=db.bind)
    53 +
    54 +    token_table = Table('oauth__tokens', metadata, autoload=True,
    55 +                        autoload_with=db.bind)
    56 +
    57 +    refresh_token = token_table.columns['refresh_token']
    58 +
    59 +    refresh_token.drop()
    60 +    db.commit()
    61 +
    62 +@RegisterMigration(3, MIGRATIONS)
    63 +def create_refresh_token_table(db):
    64 +    OAuthRefreshToken_v0.__table__.create(db.bind)
    65 +
    66 +    db.commit()
    67 diff --git a/mediagoblin/plugins/oauth/models.py b/mediagoblin/plugins/oauth/models.py
    68 index 695dad3..28735dd 100644
    69 --- a/mediagoblin/plugins/oauth/models.py
    70 +++ b/mediagoblin/plugins/oauth/models.py
    71 @@ -19,12 +19,14 @@ import bcrypt
    72  
    73  from datetime import datetime, timedelta
    74  
    75 -from mediagoblin.db.base import Base
    76 -from mediagoblin.db.models import User
    77  
    78  from sqlalchemy import (
    79          Column, Unicode, Integer, DateTime, ForeignKey, Enum)
    80  from sqlalchemy.orm import relationship
    81 +from mediagoblin.db.base import Base
    82 +from mediagoblin.db.models import User
    83 +from mediagoblin.plugins.oauth.tools import generate_identifier, \
    84 +    generate_secret, generate_token, generate_code, generate_refresh_token
    85  
    86  # Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto
    87  # the models.
    88 @@ -41,8 +43,9 @@ class OAuthClient(Base):
    89      name = Column(Unicode)
    90      description = Column(Unicode)
    91  
    92 -    identifier = Column(Unicode, unique=True, index=True)
    93 -    secret = Column(Unicode, index=True)
    94 +    identifier = Column(Unicode, unique=True, index=True,
    95 +                        default=generate_identifier)
    96 +    secret = Column(Unicode, index=True, default=generate_secret)
    97  
    98      owner_id = Column(Integer, ForeignKey(User.id))
    99      owner = relationship(User, backref='registered_clients')
    100 @@ -54,14 +57,8 @@ class OAuthClient(Base):
    101          u'public',
    102          name=u'oauth__client_type'))
    103  
    104 -    def generate_identifier(self):
    105 -        self.identifier = unicode(uuid.uuid4())
    106 -
    107 -    def generate_secret(self):
    108 -        self.secret = unicode(
    109 -                bcrypt.hashpw(
    110 -                    unicode(uuid.uuid4()),
    111 -                    bcrypt.gensalt()))
    112 +    def update_secret(self):
    113 +        self.secret = generate_secret()
    114  
    115      def __repr__(self):
    116          return '<{0} {1}:{2} ({3})>'.format(
    117 @@ -76,10 +73,10 @@ class OAuthUserClient(Base):
    118      id = Column(Integer, primary_key=True)
    119  
    120      user_id = Column(Integer, ForeignKey(User.id))
    121 -    user = relationship(User, backref='oauth_clients')
    122 +    user = relationship(User, backref='oauth_client_relations')
    123  
    124      client_id = Column(Integer, ForeignKey(OAuthClient.id))
    125 -    client = relationship(OAuthClient, backref='users')
    126 +    client = relationship(OAuthClient, backref='oauth_user_relations')
    127  
    128      state = Column(Enum(
    129          u'approved',
    130 @@ -103,8 +100,7 @@ class OAuthToken(Base):
    131              default=datetime.now)
    132      expires = Column(DateTime, nullable=False,
    133              default=lambda: datetime.now() + timedelta(days=30))
    134 -    token = Column(Unicode, index=True)
    135 -    refresh_token = Column(Unicode, index=True)
    136 +    token = Column(Unicode, index=True, default=generate_token)
    137  
    138      user_id = Column(Integer, ForeignKey(User.id), nullable=False,
    139              index=True)
    140 @@ -121,6 +117,31 @@ class OAuthToken(Base):
    141                  self.user,
    142                  self.client)
    143  
    144 +class OAuthRefreshToken(Base):
    145 +    __tablename__ = 'oauth__refresh_tokens'
    146 +
    147 +    id = Column(Integer, primary_key=True)
    148 +    created = Column(DateTime, nullable=False,
    149 +                     default=datetime.now)
    150 +
    151 +    token = Column(Unicode, index=True,
    152 +                   default=generate_refresh_token)
    153 +
    154 +    user_id = Column(Integer, ForeignKey(User.id), nullable=False,
    155 +            index=True)
    156 +
    157 +    user = relationship(User)
    158 +
    159 +    client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
    160 +    client = relationship(OAuthClient)
    161 +
    162 +    def __repr__(self):
    163 +        return '<{0} #{1} [{3}, {4}]>'.format(
    164 +                self.__class__.__name__,
    165 +                self.id,
    166 +                self.user,
    167 +                self.client)
    168 +
    169  
    170  class OAuthCode(Base):
    171      __tablename__ = 'oauth__codes'
    172 @@ -130,7 +151,7 @@ class OAuthCode(Base):
    173              default=datetime.now)
    174      expires = Column(DateTime, nullable=False,
    175              default=lambda: datetime.now() + timedelta(minutes=5))
    176 -    code = Column(Unicode, index=True)
    177 +    code = Column(Unicode, index=True, default=generate_code)
    178  
    179      user_id = Column(Integer, ForeignKey(User.id), nullable=False,
    180              index=True)
    181 @@ -150,6 +171,7 @@ class OAuthCode(Base):
    182  
    183  MODELS = [
    184          OAuthToken,
    185 +        OAuthRefreshToken,
    186          OAuthCode,
    187          OAuthClient,
    188          OAuthUserClient]
    189 diff --git a/mediagoblin/plugins/oauth/tools.py b/mediagoblin/plugins/oauth/tools.py
    190 index d21c8a5..25d0977 100644
    191 --- a/mediagoblin/plugins/oauth/tools.py
    192 +++ b/mediagoblin/plugins/oauth/tools.py
    193 @@ -1,3 +1,4 @@
    194 +# -*- coding: utf-8 -*-
    195  # GNU MediaGoblin -- federated, autonomous media hosting
    196  # Copyright (C) 2011, 2012 MediaGoblin contributors.  See AUTHORS.
    197  #
    198 @@ -14,13 +15,25 @@
    199  # You should have received a copy of the GNU Affero General Public License
    200  # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    201  
    202 +import uuid
    203 +import bcrypt
    204 +
    205 +from datetime import datetime
    206 +
    207  from functools import wraps
    208  
    209 -from mediagoblin.plugins.oauth.models import OAuthClient
    210  from mediagoblin.plugins.api.tools import json_response
    211  
    212  
    213  def require_client_auth(controller):
    214 +    '''
    215 +    View decorator
    216 +
    217 +    - Requires the presence of ``?client_id``
    218 +    '''
    219 +    # Avoid circular import
    220 +    from mediagoblin.plugins.oauth.models import OAuthClient
    221 +
    222      @wraps(controller)
    223      def wrapper(request, *args, **kw):
    224          if not request.GET.get('client_id'):
    225 @@ -41,3 +54,63 @@ def require_client_auth(controller):
    226          return controller(request, client)
    227  
    228      return wrapper
    229 +
    230 +
    231 +def create_token(client, user):
    232 +    '''
    233 +    Create an OAuthToken and an OAuthRefreshToken entry in the database
    234 +
    235 +    Returns the data structure expected by the OAuth clients.
    236 +    '''
    237 +    from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken
    238 +
    239 +    token = OAuthToken()
    240 +    token.user = user
    241 +    token.client = client
    242 +    token.save()
    243 +
    244 +    refresh_token = OAuthRefreshToken()
    245 +    refresh_token.user = user
    246 +    refresh_token.client = client
    247 +    refresh_token.save()
    248 +
    249 +    # expire time of token in full seconds
    250 +    # timedelta.total_seconds is python >= 2.7 or we would use that
    251 +    td = token.expires - datetime.now()
    252 +    exp_in = 86400*td.days + td.seconds # just ignore µsec
    253 +
    254 +    return {'access_token': token.token, 'token_type': 'bearer',
    255 +            'refresh_token': refresh_token.token, 'expires_in': exp_in}
    256 +
    257 +
    258 +def generate_identifier():
    259 +    ''' Generates a ``uuid.uuid4()`` '''
    260 +    return unicode(uuid.uuid4())
    261 +
    262 +
    263 +def generate_token():
    264 +    ''' Uses generate_identifier '''
    265 +    return generate_identifier()
    266 +
    267 +
    268 +def generate_refresh_token():
    269 +    ''' Uses generate_identifier '''
    270 +    return generate_identifier()
    271 +
    272 +
    273 +def generate_code():
    274 +    ''' Uses generate_identifier '''
    275 +    return generate_identifier()
    276 +
    277 +
    278 +def generate_secret():
    279 +    '''
    280 +    Generate a long string of pseudo-random characters
    281 +    '''
    282 +    # XXX: We might not want it to use bcrypt, since bcrypt takes its time to
    283 +    # generate the result.
    284 +    return unicode(
    285 +            bcrypt.hashpw(
    286 +                unicode(uuid.uuid4()),
    287 +                bcrypt.gensalt()))
    288 +
    289 diff --git a/mediagoblin/plugins/oauth/views.py b/mediagoblin/plugins/oauth/views.py
    290 index c7b2a33..ad8ea8f 100644
    291 --- a/mediagoblin/plugins/oauth/views.py
    292 +++ b/mediagoblin/plugins/oauth/views.py
    293 @@ -16,21 +16,21 @@
    294  # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    295  
    296  import logging
    297 -import json
    298  
    299  from urllib import urlencode
    300 -from uuid import uuid4
    301 -from datetime import datetime
    302 +
    303 +from werkzeug.exceptions import BadRequest
    304  
    305  from mediagoblin.tools.response import render_to_response, redirect
    306  from mediagoblin.decorators import require_active_login
    307 -from mediagoblin.messages import add_message, SUCCESS, ERROR
    308 +from mediagoblin.messages import add_message, SUCCESS
    309  from mediagoblin.tools.translate import pass_to_ugettext as _
    310 -from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \
    311 -        OAuthClient, OAuthUserClient
    312 +from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \
    313 +        OAuthUserClient, OAuthRefreshToken
    314  from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \
    315          AuthorizationForm
    316 -from mediagoblin.plugins.oauth.tools import require_client_auth
    317 +from mediagoblin.plugins.oauth.tools import require_client_auth, \
    318 +        create_token
    319  from mediagoblin.plugins.api.tools import json_response
    320  
    321  _log = logging.getLogger(__name__)
    322 @@ -51,9 +51,6 @@ def register_client(request):
    323          client.owner_id = request.user.id
    324          client.redirect_uri = unicode(request.form['redirect_uri'])
    325  
    326 -        client.generate_identifier()
    327 -        client.generate_secret()
    328 -
    329          client.save()
    330  
    331          add_message(request, SUCCESS, _('The client {0} has been registered!')\
    332 @@ -92,8 +89,8 @@ def authorize_client(request):
    333          form.client_id.data).first()
    334  
    335      if not client:
    336 -        _log.error('''No such client id as received from client authorization
    337 -                form.''')
    338 +        _log.error('No such client id as received from client authorization \
    339 +form.')
    340          return BadRequest()
    341  
    342      if form.validate():
    343 @@ -105,7 +102,7 @@ def authorize_client(request):
    344          elif form.deny.data:
    345              relation.state = u'rejected'
    346          else:
    347 -            return BadRequest
    348 +            return BadRequest()
    349  
    350          relation.save()
    351  
    352 @@ -136,7 +133,7 @@ def authorize(request, client):
    353                  return json_response({
    354                      'status': 400,
    355                      'errors':
    356 -                        [u'Public clients MUST have a redirect_uri pre-set']},
    357 +                        [u'Public clients should have a redirect_uri pre-set.']},
    358                          _disable_cors=True)
    359  
    360              redirect_uri = client.redirect_uri
    361 @@ -146,11 +143,10 @@ def authorize(request, client):
    362              if not redirect_uri:
    363                  return json_response({
    364                      'status': 400,
    365 -                    'errors': [u'Can not find a redirect_uri for client: {0}'\
    366 -                            .format(client.name)]}, _disable_cors=True)
    367 +                    'errors': [u'No redirect_uri supplied!']},
    368 +                    _disable_cors=True)
    369  
    370          code = OAuthCode()
    371 -        code.code = unicode(uuid4())
    372          code.user = request.user
    373          code.client = client
    374          code.save()
    375 @@ -180,59 +176,79 @@ def authorize(request, client):
    376  
    377  
    378  def access_token(request):
    379 +    '''
    380 +    Access token endpoint provides access tokens to any clients that have the
    381 +    right grants/credentials
    382 +    '''
    383 +
    384 +    client = None
    385 +    user = None
    386 +
    387      if request.GET.get('code'):
    388 +        # Validate the code arg, then get the client object from the db.
    389          code = OAuthCode.query.filter(OAuthCode.code ==
    390                  request.GET.get('code')).first()
    391  
    392 -        if code:
    393 -            if code.client.type == u'confidential':
    394 -                client_identifier = request.GET.get('client_id')
    395 -
    396 -                if not client_identifier:
    397 -                    return json_response({
    398 -                        'error': 'invalid_request',
    399 -                        'error_description':
    400 -                            'Missing client_id in request'})
    401 -
    402 -                client_secret = request.GET.get('client_secret')
    403 -
    404 -                if not client_secret:
    405 -                    return json_response({
    406 -                        'error': 'invalid_request',
    407 -                        'error_description':
    408 -                            'Missing client_secret in request'})
    409 -
    410 -                if not client_secret == code.client.secret or \
    411 -                        not client_identifier == code.client.identifier:
    412 -                    return json_response({
    413 -                        'error': 'invalid_client',
    414 -                        'error_description':
    415 -                            'The client_id or client_secret does not match the'
    416 -                            ' code'})
    417 -
    418 -            token = OAuthToken()
    419 -            token.token = unicode(uuid4())
    420 -            token.user = code.user
    421 -            token.client = code.client
    422 -            token.save()
    423 -
    424 -            # expire time of token in full seconds
    425 -            # timedelta.total_seconds is python >= 2.7 or we would use that
    426 -            td = token.expires - datetime.now()
    427 -            exp_in = 86400*td.days + td.seconds # just ignore µsec
    428 -
    429 -            access_token_data = {
    430 -                'access_token': token.token,
    431 -                'token_type': 'bearer',
    432 -                'expires_in': exp_in}
    433 -            return json_response(access_token_data, _disable_cors=True)
    434 -        else:
    435 +        if not code:
    436              return json_response({
    437                  'error': 'invalid_request',
    438                  'error_description':
    439 -                    'Invalid code'})
    440 -    else:
    441 -        return json_response({
    442 -            'error': 'invalid_request',
    443 -            'error_descriptin':
    444 -                'Missing `code` parameter in request'})
    445 +                    'Invalid code.'})
    446 +
    447 +        client = code.client
    448 +        user = code.user
    449 +
    450 +    elif request.args.get('refresh_token'):
    451 +        # Validate a refresh token, then get the client object from the db.
    452 +        refresh_token = OAuthRefreshToken.query.filter(
    453 +            OAuthRefreshToken.token ==
    454 +            request.args.get('refresh_token')).first()
    455 +
    456 +        if not refresh_token:
    457 +            return json_response({
    458 +                'error': 'invalid_request',
    459 +                'error_description':
    460 +                    'Invalid refresh token.'})
    461 +
    462 +        client = refresh_token.client
    463 +        user = refresh_token.user
    464 +
    465 +    if client:
    466 +        client_identifier = request.GET.get('client_id')
    467 +
    468 +        if not client_identifier:
    469 +            return json_response({
    470 +                'error': 'invalid_request',
    471 +                'error_description':
    472 +                    'Missing client_id in request.'})
    473 +
    474 +        if not client_identifier == client.identifier:
    475 +            return json_response({
    476 +                'error': 'invalid_client',
    477 +                'error_description':
    478 +                    'Mismatching client credentials.'})
    479 +
    480 +        if client.type == u'confidential':
    481 +            client_secret = request.GET.get('client_secret')
    482 +
    483 +            if not client_secret:
    484 +                return json_response({
    485 +                    'error': 'invalid_request',
    486 +                    'error_description':
    487 +                        'Missing client_secret in request.'})
    488 +
    489 +            if not client_secret == client.secret:
    490 +                return json_response({
    491 +                    'error': 'invalid_client',
    492 +                    'error_description':
    493 +                        'Mismatching client credentials.'})
    494 +
    495 +
    496 +        access_token_data = create_token(client, user)
    497 +
    498 +        return json_response(access_token_data, _disable_cors=True)
    499 +
    500 +    return json_response({
    501 +        'error': 'invalid_request',
    502 +        'error_description':
    503 +            'Missing `code` or `refresh_token` parameter in request.'})
    504 diff --git a/mediagoblin/tests/test_oauth.py b/mediagoblin/tests/test_oauth.py
    505 index 94ba5da..f036569 100644
    506 --- a/mediagoblin/tests/test_oauth.py
    507 +++ b/mediagoblin/tests/test_oauth.py
    508 @@ -16,12 +16,13 @@
    509  
    510  import json
    511  import logging
    512 +import urllib
    513  
    514  from urlparse import parse_qs, urlparse
    515  
    516  from mediagoblin import mg_globals
    517  from mediagoblin.tools import template, pluginapi
    518 -from mediagoblin.tests.tools import get_app, fixture_add_user
    519 +from mediagoblin.tests.tools import get_app, fixture_add_user, expect_failure
    520  
    521  
    522  _log = logging.getLogger(__name__)
    523 @@ -70,7 +71,7 @@ class TestOAuth(object):
    524          assert response.status_int == 200
    525  
    526          # Should display an error
    527 -        assert ctx['form'].redirect_uri.errors
    528 +        assert len(ctx['form'].redirect_uri.errors)
    529  
    530          # Should not pass through
    531          assert not client
    532 @@ -78,12 +79,16 @@ class TestOAuth(object):
    533      def test_2_successful_public_client_registration(self):
    534          ''' Successfully register a public client '''
    535          self.login()
    536 +        uri = 'http://foo.example'
    537          self.register_client(u'OMGOMG', 'public', 'OMG!',
    538 -                'http://foo.example')
    539 +                uri)
    540  
    541          client = self.db.OAuthClient.query.filter(
    542                  self.db.OAuthClient.name == u'OMGOMG').first()
    543  
    544 +        # redirect_uri should be set
    545 +        assert client.redirect_uri == uri
    546 +
    547          # Client should have been registered
    548          assert client
    549  
    550 @@ -111,7 +116,7 @@ class TestOAuth(object):
    551          redirect_uri = 'https://foo.example'
    552          response = self.app.get('/oauth/authorize', {
    553                  'client_id': client.identifier,
    554 -                'scope': 'admin',
    555 +                'scope': 'all',
    556                  'redirect_uri': redirect_uri})
    557  
    558          # User-agent should NOT be redirected
    559 @@ -137,6 +142,7 @@ class TestOAuth(object):
    560          return authorization_response, client_identifier
    561  
    562      def get_code_from_redirect_uri(self, uri):
    563 +        ''' Get the value of ?code= from an URI '''
    564          return parse_qs(urlparse(uri).query)['code'][0]
    565  
    566      def test_token_endpoint_successful_confidential_request(self):
    567 @@ -162,6 +168,11 @@ code={1}&client_secret={2}'.format(client_id, code, client.secret))
    568          assert type(token_data['expires_in']) == int
    569          assert token_data['expires_in'] > 0
    570  
    571 +        # There should be a refresh token provided in the token data
    572 +        assert len(token_data['refresh_token'])
    573 +
    574 +        return client_id, token_data
    575 +
    576      def test_token_endpont_missing_id_confidential_request(self):
    577          ''' Unsuccessful request against token endpoint, missing client_id '''
    578          code_redirect, client_id = self.test_4_authorize_confidential_client()
    579 @@ -181,4 +192,30 @@ code={0}&client_secret={1}'.format(code, client.secret))
    580          assert 'error' in token_data
    581          assert not 'access_token' in token_data
    582          assert token_data['error'] == 'invalid_request'
    583 -        assert token_data['error_description'] == 'Missing client_id in request'
    584 +        assert len(token_data['error_description'])
    585 +
    586 +    def test_refresh_token(self):
    587 +        ''' Try to get a new access token using the refresh token '''
    588 +        # Get an access token and a refresh token
    589 +        client_id, token_data =\
    590 +            self.test_token_endpoint_successful_confidential_request()
    591 +
    592 +        client = self.db.OAuthClient.query.filter(
    593 +            self.db.OAuthClient.identifier == client_id).first()
    594 +
    595 +        token_res = self.app.get('/oauth/access_token',
    596 +                     {'refresh_token': token_data['refresh_token'],
    597 +                      'client_id': client_id,
    598 +                      'client_secret': client.secret
    599 +                      })
    600 +
    601 +        assert token_res.status_int == 200
    602 +
    603 +        new_token_data = json.loads(token_res.body)
    604 +
    605 +        assert not 'error' in new_token_data
    606 +        assert 'access_token' in new_token_data
    607 +        assert 'token_type' in new_token_data
    608 +        assert 'expires_in' in new_token_data
    609 +        assert type(new_token_data['expires_in']) == int
    610 +        assert new_token_data['expires_in'] > 0
    611 diff --git a/mediagoblin/tests/tools.py b/mediagoblin/tests/tools.py
    612 index cc4a7ad..71954e0 100644
    613 --- a/mediagoblin/tests/tools.py
    614 +++ b/mediagoblin/tests/tools.py
    615 @@ -19,6 +19,8 @@ import os
    616  import pkg_resources
    617  import shutil
    618  
    619 +import nose
    620 +
    621  from functools import wraps
    622  
    623  from paste.deploy import loadapp
    624 @@ -235,7 +237,7 @@ def fixture_media_entry(title=u"Some title", slug=None,
    625      entry.slug = slug
    626      entry.uploader = uploader or fixture_add_user().id
    627      entry.media_type = u'image'
    628 -   
    629 +
    630      if gen_slug:
    631          entry.generate_slug()
    632      if save:
    633 @@ -263,3 +265,16 @@ def fixture_add_collection(name=u"My first Collection", user=None):
    634      Session.expunge(coll)
    635  
    636      return coll
    637 +
    638 +
    639 +def expect_failure(test):
    640 +    @wraps(test)
    641 +    def wrapper(*args, **kwargs):
    642 +        try:
    643 +            test(*args, **kwargs)
    644 +        except Exception:
    645 +            raise nose.SkipTest
    646 +        else:
    647 +            raise AssertionError(
    648 +                'Test is expected to fail, did you add a feature?')
    649 +    return wrapper
    650 }}}
     3 ''Edited to remove a very long diff.''