wisemodel-to-hf / git_commd.py
chenlei1983's picture
Update git_commd.py
d013f47 verified
import os
import subprocess
from typing import List, Optional
from logger import get_logger
logger = get_logger()
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton,
cls).__call__(*args, **kwargs)
return cls._instances[cls]
class GitError(Exception):
pass
class GitCommandWrapper(metaclass=Singleton):
"""Some git operation wrapper
"""
default_git_path = 'git' # The default git command line
git_path='git'
MODEL_BRANCH="main"
def __init__(self, path: str = None):
self.git_path = path or self.default_git_path
def _run_git_command(self, *args) -> subprocess.CompletedProcess:
"""Run git command, if command return 0, return subprocess.response
otherwise raise GitError, message is stdout and stderr.
Args:
args: List of command args.
Raises:
GitError: Exception with stdout and stderr.
Returns:
subprocess.CompletedProcess: the command response
"""
logger.debug(' '.join(args))
git_env = os.environ.copy()
git_env['GIT_TERMINAL_PROMPT'] = '0'
response = subprocess.run(
[self.git_path, *args],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=git_env,
) # compatible for python3.6
try:
response.check_returncode()
return response
except subprocess.CalledProcessError as error:
logger.error('There are error run git command.')
raise GitError(
'stdout: %s, stderr: %s' %
(response.stdout.decode('utf8'), error.stderr.decode('utf8')))
def config_auth_token(self, repo_dir, auth_token):
url = self.get_repo_remote_url(repo_dir)
if '//oauth2' not in url:
auth_url = self._add_token(auth_token, url)
cmd_args = '-C %s remote set-url origin %s' % (repo_dir, auth_url)
cmd_args = cmd_args.split(' ')
rsp = self._run_git_command(*cmd_args)
logger.debug(rsp.stdout.decode('utf8'))
def _add_token(self, token: str, url: str):
if token:
if '//oauth2' not in url:
url = url.replace('//', '//oauth2:%s@' % token)
return url
def remove_token_from_url(self, url: str):
if url and '//oauth2' in url:
start_index = url.find('oauth2')
end_index = url.find('@')
url = url[:start_index] + url[end_index + 1:]
return url
def is_lfs_installed(self):
cmd = ['lfs', 'env']
try:
self._run_git_command(*cmd)
return True
except GitError:
return False
def git_lfs_install(self, repo_dir):
cmd = ['-C', repo_dir, 'lfs', 'install']
try:
self._run_git_command(*cmd)
return True
except GitError:
return False
def clone(self,
repo_base_dir: str,
token: str,
url: str,
repo_name: str,
branch: Optional[str] = None):
""" git clone command wrapper.
For public project, token can None, private repo, there must token.
Args:
repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name
token (str): The git token, must be provided for private project.
url (str): The remote url
repo_name (str): The local repository path name.
branch (str, optional): _description_. Defaults to None.
Returns:
The popen response.
"""
url = self._add_token(token, url)
if branch:
clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url,
repo_name, branch)
else:
clone_args = '-C %s clone %s' % (repo_base_dir, url)
logger.debug(clone_args)
clone_args = clone_args.split(' ')
response = self._run_git_command(*clone_args)
logger.debug(response.stdout.decode('utf8'))
return response
def add(self,
repo_dir: str,
files: List[str] = list(),
all_files: bool = False):
if all_files:
add_args = '-C %s add -A' % repo_dir
elif len(files) > 0:
files_str = ' '.join(files)
add_args = '-C %s add %s' % (repo_dir, files_str)
add_args = add_args.split(' ')
rsp = self._run_git_command(*add_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp
def commit(self, repo_dir: str, message: str):
"""Run git commit command
Args:
repo_dir (str): the repository directory.
message (str): commit message.
Returns:
The command popen response.
"""
commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message]
rsp = self._run_git_command(*commit_args)
logger.info(rsp.stdout.decode('utf8'))
return rsp
def checkout(self, repo_dir: str, revision: str):
cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision]
return self._run_git_command(*cmds)
def new_branch(self, repo_dir: str, revision: str):
cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision]
return self._run_git_command(*cmds)
def get_remote_branches(self, repo_dir: str):
cmds = ['-C', '%s' % repo_dir, 'branch', '-r']
rsp = self._run_git_command(*cmds)
info = [
line.strip()
for line in rsp.stdout.decode('utf8').strip().split(os.linesep)
]
if len(info) == 1:
return ['/'.join(info[0].split('/')[1:])]
else:
return ['/'.join(line.split('/')[1:]) for line in info[1:]]
def pull(self,
repo_dir: str,
remote: str = 'origin',
branch: str = 'main'):
cmds = ['-C', repo_dir, 'pull', remote, branch]
return self._run_git_command(*cmds)
def push(self,
repo_dir: str,
token: str,
url: str,
local_branch: str,
remote_branch: str,
force: bool = False):
url = self._add_token(token, url)
push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch,
remote_branch)
if force:
push_args += ' -f'
push_args = push_args.split(' ')
rsp = self._run_git_command(*push_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp
def get_repo_remote_url(self, repo_dir: str):
cmd_args = '-C %s config --get remote.origin.url' % repo_dir
cmd_args = cmd_args.split(' ')
rsp = self._run_git_command(*cmd_args)
url = rsp.stdout.decode('utf8')
return url.strip()
def list_lfs_files(self, repo_dir: str):
cmd_args = '-C %s lfs ls-files' % repo_dir
cmd_args = cmd_args.split(' ')
rsp = self._run_git_command(*cmd_args)
out = rsp.stdout.decode('utf8').strip()
files = []
for line in out.split(os.linesep):
files.append(line.split(' ')[-1])
return files
def tag(self,
repo_dir: str,
tag_name: str,
message: str,
ref: str = MODEL_BRANCH):
cmd_args = [
'-C', repo_dir, 'tag', tag_name, '-m',
'"%s"' % message, ref
]
rsp = self._run_git_command(*cmd_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp
def push_tag(self, repo_dir: str, tag_name):
cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name]
rsp = self._run_git_command(*cmd_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp