feat(gitstore): honor configured branch and follow live remote default

This commit is contained in:
Duong M. CUONG
2026-04-02 14:44:44 +00:00
parent 96f55570f7
commit 058793c73a
6 changed files with 838 additions and 6 deletions
+242 -5
View File
@@ -32,16 +32,24 @@ type GitTokenStore struct {
repoDir string
configDir string
remote string
branch string
username string
password string
lastGC time.Time
}
type resolvedRemoteBranch struct {
name plumbing.ReferenceName
hash plumbing.Hash
}
// NewGitTokenStore creates a token store that saves credentials to disk through the
// TokenStorage implementation embedded in the token record.
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
return &GitTokenStore{
remote: remote,
branch: strings.TrimSpace(branch),
username: username,
password: password,
}
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: create repo dir: %w", errMk)
}
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
if s.branch != "" {
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
_ = os.RemoveAll(gitDir)
repo, errInit := git.PlainInit(repoDir, false)
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: init empty repo: %w", errInit)
}
if s.branch != "" {
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
s.dirLock.Unlock()
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
}
}
if _, errRemote := repo.Remote("origin"); errRemote != nil {
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
Name: "origin",
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: worktree: %w", errWorktree)
}
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
if s.branch != "" {
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
s.dirLock.Unlock()
return errCheckout
}
} else {
// When branch is unset, ensure the working tree follows the remote default branch
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
if !shouldFallbackToCurrentBranch(repo, err) {
s.dirLock.Unlock()
return fmt.Errorf("git token store: checkout remote default: %w", err)
}
}
}
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
if s.branch != "" {
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if errPull := worktree.Pull(pullOpts); errPull != nil {
switch {
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
errors.Is(errPull, git.ErrUnstagedChanges),
errors.Is(errPull, git.ErrNonFastForwardUpdate):
// Ignore clean syncs, local edits, and remote divergence—local changes win.
case errors.Is(errPull, transport.ErrAuthenticationRequired),
errors.Is(errPull, plumbing.ErrReferenceNotFound),
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
// Ignore authentication prompts and empty remote references on initial sync.
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
if s.branch != "" {
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
}
// Ignore missing references only when following the remote default branch.
default:
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
@@ -553,6 +595,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
return rel, nil
}
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
branchRefName := plumbing.NewBranchReferenceName(s.branch)
headRef, errHead := repo.Head()
switch {
case errHead == nil && headRef.Name() == branchRefName:
return nil
case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound):
return fmt.Errorf("git token store: get head: %w", errHead)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil {
return nil
} else if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
} else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
}
return nil
}
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
remoteRef, err := repo.Reference(remoteRefName, true)
if errors.Is(err, plumbing.ErrReferenceNotFound) {
if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
return fmt.Errorf("sync remote refs: %w", errSync)
}
remoteRef, err = repo.Reference(remoteRefName, true)
}
if err != nil {
return err
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
return err
}
cfg, err := repo.Config()
if err != nil {
return fmt.Errorf("git token store: repo config: %w", err)
}
if _, ok := cfg.Branches[s.branch]; !ok {
cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
}
cfg.Branches[s.branch].Remote = "origin"
cfg.Branches[s.branch].Merge = branchRefName
if err := repo.SetConfig(cfg); err != nil {
return fmt.Errorf("git token store: set branch config: %w", err)
}
return nil
}
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
return err
}
return nil
}
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
if err := syncRemoteReferences(repo, authMethod); err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
}
remote, err := repo.Remote("origin")
if err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
}
refs, err := remote.List(&git.ListOptions{Auth: authMethod})
if err != nil {
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
}
for _, r := range refs {
if r.Name() == plumbing.HEAD {
if r.Type() == plumbing.SymbolicReference {
if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
s := r.String()
if idx := strings.Index(s, "->"); idx != -1 {
if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
}
}
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
for _, r := range refs {
if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
}
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
}
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
if err != nil || ref.Type() != plumbing.SymbolicReference {
return resolvedRemoteBranch{}, false
}
target, ok := normalizeRemoteBranchReference(ref.Target())
if !ok {
return resolvedRemoteBranch{}, false
}
return resolvedRemoteBranch{name: target}, true
}
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
switch {
case strings.HasPrefix(name.String(), "refs/heads/"):
return name, true
case strings.HasPrefix(name.String(), "refs/remotes/origin/"):
return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true
default:
return "", false
}
}
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
if !errors.Is(err, transport.ErrAuthenticationRequired) && !errors.Is(err, transport.ErrEmptyRemoteRepository) {
return false
}
_, headErr := repo.Head()
return headErr == nil
}
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch
// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track
// the remote branch.
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
if err != nil {
return err
}
branchRefName := resolved.name
// If HEAD already points to the desired branch, nothing to do.
headRef, errHead := repo.Head()
if errHead == nil && headRef.Name() == branchRefName {
return nil
}
// If local branch exists, attempt a checkout
if _, err := repo.Reference(branchRefName, true); err == nil {
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil {
return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err)
}
return nil
}
// Try to find the corresponding remote tracking ref (refs/remotes/origin/<name>)
branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
hash := resolved.hash
if remoteRef, err := repo.Reference(remoteRefName, true); err == nil {
hash = remoteRef.Hash()
} else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err)
}
if hash == plumbing.ZeroHash {
return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
}
cfg, err := repo.Config()
if err != nil {
return fmt.Errorf("git token store: repo config: %w", err)
}
if _, ok := cfg.Branches[branchShort]; !ok {
cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
}
cfg.Branches[branchShort].Remote = "origin"
cfg.Branches[branchShort].Merge = branchRefName
if err := repo.SetConfig(cfg); err != nil {
return fmt.Errorf("git token store: set branch config: %w", err)
}
return nil
}
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
repoDir := s.repoDirSnapshot()
if repoDir == "" {
@@ -618,7 +846,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
return errRewrite
}
s.maybeRunGC(repo)
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
if s.branch != "" {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
} else {
// When branch is unset, pin push to the currently checked-out branch.
if headRef, err := repo.Head(); err == nil {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
}
}
if err = repo.Push(pushOpts); err != nil {
if errors.Is(err, git.NoErrAlreadyUpToDate) {
return nil
}