diff --git a/rhodecode/model/scm.py b/rhodecode/model/scm.py --- a/rhodecode/model/scm.py +++ b/rhodecode/model/scm.py @@ -28,22 +28,20 @@ import traceback import logging import cStringIO -from sqlalchemy.exc import DatabaseError - -from vcs import get_backend -from vcs.exceptions import RepositoryError -from vcs.utils.lazy import LazyProperty -from vcs.nodes import FileNode +from rhodecode.lib.vcs import get_backend +from rhodecode.lib.vcs.exceptions import RepositoryError +from rhodecode.lib.vcs.utils.lazy import LazyProperty +from rhodecode.lib.vcs.nodes import FileNode from rhodecode import BACKENDS from rhodecode.lib import helpers as h from rhodecode.lib import safe_str -from rhodecode.lib.auth import HasRepoPermissionAny +from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny from rhodecode.lib.utils import get_repos as get_filesystem_repos, make_ui, \ action_logger, EmptyChangeset from rhodecode.model import BaseModel from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \ - UserFollowing, UserLog, User + UserFollowing, UserLog, User, RepoGroup log = logging.getLogger(__name__) @@ -63,6 +61,7 @@ class RepoTemp(object): def __repr__(self): return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id) + class CachedRepoList(object): def __init__(self, db_repo_list, repos_path, order_by=None): @@ -79,19 +78,18 @@ class CachedRepoList(object): def __iter__(self): for dbr in self.db_repo_list: - scmr = dbr.scm_instance_cached - # check permission at this level - if not HasRepoPermissionAny('repository.read', 'repository.write', - 'repository.admin')(dbr.repo_name, - 'get repo check'): + if not HasRepoPermissionAny( + 'repository.read', 'repository.write', 'repository.admin' + )(dbr.repo_name, 'get repo check'): continue if scmr is None: - log.error('%s this repository is present in database but it ' - 'cannot be created as an scm instance', - dbr.repo_name) + log.error( + '%s this repository is present in database but it ' + 'cannot be created as an scm instance' % dbr.repo_name + ) continue last_change = scmr.last_change @@ -103,8 +101,7 @@ class CachedRepoList(object): tmp_d['description'] = dbr.description tmp_d['description_sort'] = tmp_d['description'] tmp_d['last_change'] = last_change - tmp_d['last_change_sort'] = time.mktime(last_change \ - .timetuple()) + tmp_d['last_change_sort'] = time.mktime(last_change.timetuple()) tmp_d['tip'] = tip.raw_id tmp_d['tip_sort'] = tip.revision tmp_d['rev'] = tip.revision @@ -115,17 +112,53 @@ class CachedRepoList(object): tmp_d['last_msg'] = tip.message tmp_d['author'] = tip.author tmp_d['dbrepo'] = dbr.get_dict() - tmp_d['dbrepo_fork'] = dbr.fork.get_dict() if dbr.fork \ - else {} + tmp_d['dbrepo_fork'] = dbr.fork.get_dict() if dbr.fork else {} yield tmp_d + +class GroupList(object): + + def __init__(self, db_repo_group_list): + self.db_repo_group_list = db_repo_group_list + + def __len__(self): + return len(self.db_repo_group_list) + + def __repr__(self): + return '<%s (%s)>' % (self.__class__.__name__, self.__len__()) + + def __iter__(self): + for dbgr in self.db_repo_group_list: + # check permission at this level + if not HasReposGroupPermissionAny( + 'group.read', 'group.write', 'group.admin' + )(dbgr.group_name, 'get group repo check'): + continue + + yield dbgr + + class ScmModel(BaseModel): - """Generic Scm Model + """ + Generic Scm Model """ + def __get_repo(self, instance): + cls = Repository + if isinstance(instance, cls): + return instance + elif isinstance(instance, int) or str(instance).isdigit(): + return cls.get(instance) + elif isinstance(instance, basestring): + return cls.get_by_repo_name(instance) + elif instance: + raise Exception('given object must be int, basestr or Instance' + ' of %s got %s' % (type(cls), type(instance))) + @LazyProperty def repos_path(self): - """Get's the repositories root path from database + """ + Get's the repositories root path from database """ q = self.sa.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key == '/').one() @@ -133,7 +166,8 @@ class ScmModel(BaseModel): return q.ui_value def repo_scan(self, repos_path=None): - """Listing of repositories in given path. This path should not be a + """ + Listing of repositories in given path. This path should not be a repository itself. Return a dictionary of repository objects :param repos_path: path to directory containing repositories @@ -142,19 +176,19 @@ class ScmModel(BaseModel): if repos_path is None: repos_path = self.repos_path - log.info('scanning for repositories in %s', repos_path) + log.info('scanning for repositories in %s' % repos_path) baseui = make_ui('db') - repos_list = {} + repos = {} for name, path in get_filesystem_repos(repos_path, recursive=True): - + # name need to be decomposed and put back together using the / # since this is internal storage separator for rhodecode name = Repository.url_sep().join(name.split(os.sep)) - + try: - if name in repos_list: + if name in repos: raise RepositoryError('Duplicate repository name %s ' 'found in %s' % (name, path)) else: @@ -162,17 +196,14 @@ class ScmModel(BaseModel): klass = get_backend(path[0]) if path[0] == 'hg' and path[0] in BACKENDS.keys(): - - # for mercurial we need to have an str path - repos_list[name] = klass(safe_str(path[1]), - baseui=baseui) + repos[name] = klass(safe_str(path[1]), baseui=baseui) if path[0] == 'git' and path[0] in BACKENDS.keys(): - repos_list[name] = klass(path[1]) + repos[name] = klass(path[1]) except OSError: continue - return repos_list + return repos def get_repos(self, all_repos=None, sort_key=None): """ @@ -192,30 +223,22 @@ class ScmModel(BaseModel): return repo_iter + def get_repos_groups(self, all_groups=None): + if all_groups is None: + all_groups = RepoGroup.query()\ + .filter(RepoGroup.group_parent_id == None).all() + group_iter = GroupList(all_groups) + + return group_iter + def mark_for_invalidation(self, repo_name): """Puts cache invalidation task into db for further global cache invalidation :param repo_name: this repo that should invalidation take place """ - - log.debug('marking %s for invalidation', repo_name) - cache = self.sa.query(CacheInvalidation)\ - .filter(CacheInvalidation.cache_key == repo_name).scalar() - - if cache: - # mark this cache as inactive - cache.cache_active = False - else: - log.debug('cache key not found in invalidation db -> creating one') - cache = CacheInvalidation(repo_name) - - try: - self.sa.add(cache) - self.sa.commit() - except (DatabaseError,): - log.error(traceback.format_exc()) - self.sa.rollback() + CacheInvalidation.set_invalidate(repo_name) + CacheInvalidation.set_invalidate(repo_name + "_README") def toggle_following_repo(self, follow_repo_id, user_id): @@ -224,17 +247,14 @@ class ScmModel(BaseModel): .filter(UserFollowing.user_id == user_id).scalar() if f is not None: - try: self.sa.delete(f) - self.sa.commit() action_logger(UserTemp(user_id), 'stopped_following_repo', RepoTemp(follow_repo_id)) return except: log.error(traceback.format_exc()) - self.sa.rollback() raise try: @@ -242,13 +262,12 @@ class ScmModel(BaseModel): f.user_id = user_id f.follows_repo_id = follow_repo_id self.sa.add(f) - self.sa.commit() + action_logger(UserTemp(user_id), 'started_following_repo', RepoTemp(follow_repo_id)) except: log.error(traceback.format_exc()) - self.sa.rollback() raise def toggle_following_user(self, follow_user_id, user_id): @@ -259,11 +278,9 @@ class ScmModel(BaseModel): if f is not None: try: self.sa.delete(f) - self.sa.commit() return except: log.error(traceback.format_exc()) - self.sa.rollback() raise try: @@ -271,10 +288,8 @@ class ScmModel(BaseModel): f.user_id = user_id f.follows_user_id = follow_user_id self.sa.add(f) - self.sa.commit() except: log.error(traceback.format_exc()) - self.sa.rollback() raise def is_following_repo(self, repo_name, user_id, cache=False): @@ -310,6 +325,13 @@ class ScmModel(BaseModel): return self.sa.query(Repository)\ .filter(Repository.fork_id == repo_id).count() + def mark_as_fork(self, repo, fork, user): + repo = self.__get_repo(repo) + fork = self.__get_repo(fork) + repo.fork = fork + self.sa.add(repo) + return repo + def pull_changes(self, repo_name, username): dbrepo = Repository.get_by_repo_name(repo_name) clone_uri = dbrepo.clone_uri @@ -333,13 +355,13 @@ class ScmModel(BaseModel): log.error(traceback.format_exc()) raise - def commit_change(self, repo, repo_name, cs, user, author, message, content, - f_path): + def commit_change(self, repo, repo_name, cs, user, author, message, + content, f_path): if repo.alias == 'hg': - from vcs.backends.hg import MercurialInMemoryChangeset as IMC + from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset as IMC elif repo.alias == 'git': - from vcs.backends.git import GitInMemoryChangeset as IMC + from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset as IMC # decoding here will force that we have proper encoded values # in any other case this will throw exceptions and deny commit @@ -363,9 +385,9 @@ class ScmModel(BaseModel): def create_node(self, repo, repo_name, cs, user, author, message, content, f_path): if repo.alias == 'hg': - from vcs.backends.hg import MercurialInMemoryChangeset as IMC + from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset as IMC elif repo.alias == 'git': - from vcs.backends.git import GitInMemoryChangeset as IMC + from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset as IMC # decoding here will force that we have proper encoded values # in any other case this will throw exceptions and deny commit @@ -400,20 +422,35 @@ class ScmModel(BaseModel): self.mark_for_invalidation(repo_name) + def get_nodes(self, repo_name, revision, root_path='/', flat=True): + """ + recursive walk in root dir and return a set of all path in that dir + based on repository walk function + + :param repo_name: name of repository + :param revision: revision for which to list nodes + :param root_path: root path to list + :param flat: return as a list, if False returns a dict with decription + + """ + _files = list() + _dirs = list() + try: + _repo = self.__get_repo(repo_name) + changeset = _repo.scm_instance.get_changeset(revision) + root_path = root_path.lstrip('/') + for topnode, dirs, files in changeset.walk(root_path): + for f in files: + _files.append(f.path if flat else {"name": f.path, + "type": "file"}) + for d in dirs: + _dirs.append(d.path if flat else {"name": d.path, + "type": "dir"}) + except RepositoryError: + log.debug(traceback.format_exc()) + raise + + return _dirs, _files def get_unread_journal(self): return self.sa.query(UserLog).count() - - def _should_invalidate(self, repo_name): - """Looks up database for invalidation signals for this repo_name - - :param repo_name: - """ - - ret = self.sa.query(CacheInvalidation)\ - .filter(CacheInvalidation.cache_key == repo_name)\ - .filter(CacheInvalidation.cache_active == False)\ - .scalar() - - return ret -