Skip to content

Commit 356672c

Browse files
junnplusAkihiroSuda
authored andcommitted
refactor: reduce duplicate code
Signed-off-by: Ye Sijun <junnplus@gmail.com> (cherry picked from commit 1ab42be) > Conflicts: > oci/spec_opts.go > oci/spec_opts_linux_test.go Signed-off-by: Akihiro Suda <akihiro.suda.cz@hco.ntt.co.jp>
1 parent 6a7b761 commit 356672c

File tree

2 files changed

+136
-36
lines changed

2 files changed

+136
-36
lines changed

oci/spec_opts.go

+18-35
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts {
618618
func WithUserID(uid uint32) SpecOpts {
619619
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
620620
setProcess(s)
621-
if c.Snapshotter == "" && c.SnapshotKey == "" {
622-
if !isRootfsAbs(s.Root.Path) {
623-
return errors.Errorf("rootfs absolute path is required")
624-
}
625-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
621+
setUser := func(root string) error {
622+
user, err := UserFromPath(root, func(u user.User) bool {
626623
return u.Uid == int(uid)
627624
})
628625
if err != nil {
@@ -634,7 +631,12 @@ func WithUserID(uid uint32) SpecOpts {
634631
}
635632
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
636633
return nil
637-
634+
}
635+
if c.Snapshotter == "" && c.SnapshotKey == "" {
636+
if !isRootfsAbs(s.Root.Path) {
637+
return errors.New("rootfs absolute path is required")
638+
}
639+
return setUser(s.Root.Path)
638640
}
639641
if c.Snapshotter == "" {
640642
return errors.Errorf("no snapshotter set for container")
@@ -649,20 +651,7 @@ func WithUserID(uid uint32) SpecOpts {
649651
}
650652

651653
mounts = tryReadonlyMounts(mounts)
652-
return mount.WithTempMount(ctx, mounts, func(root string) error {
653-
user, err := UserFromPath(root, func(u user.User) bool {
654-
return u.Uid == int(uid)
655-
})
656-
if err != nil {
657-
if os.IsNotExist(err) || err == ErrNoUsersFound {
658-
s.Process.User.UID, s.Process.User.GID = uid, 0
659-
return nil
660-
}
661-
return err
662-
}
663-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
664-
return nil
665-
})
654+
return mount.WithTempMount(ctx, mounts, setUser)
666655
}
667656
}
668657

@@ -674,11 +663,8 @@ func WithUsername(username string) SpecOpts {
674663
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
675664
setProcess(s)
676665
if s.Linux != nil {
677-
if c.Snapshotter == "" && c.SnapshotKey == "" {
678-
if !isRootfsAbs(s.Root.Path) {
679-
return errors.Errorf("rootfs absolute path is required")
680-
}
681-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
666+
setUser := func(root string) error {
667+
user, err := UserFromPath(root, func(u user.User) bool {
682668
return u.Name == username
683669
})
684670
if err != nil {
@@ -687,6 +673,12 @@ func WithUsername(username string) SpecOpts {
687673
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
688674
return nil
689675
}
676+
if c.Snapshotter == "" && c.SnapshotKey == "" {
677+
if !isRootfsAbs(s.Root.Path) {
678+
return errors.New("rootfs absolute path is required")
679+
}
680+
return setUser(s.Root.Path)
681+
}
690682
if c.Snapshotter == "" {
691683
return errors.Errorf("no snapshotter set for container")
692684
}
@@ -700,16 +692,7 @@ func WithUsername(username string) SpecOpts {
700692
}
701693

702694
mounts = tryReadonlyMounts(mounts)
703-
return mount.WithTempMount(ctx, mounts, func(root string) error {
704-
user, err := UserFromPath(root, func(u user.User) bool {
705-
return u.Name == username
706-
})
707-
if err != nil {
708-
return err
709-
}
710-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
711-
return nil
712-
})
695+
return mount.WithTempMount(ctx, mounts, setUser)
713696
} else if s.Windows != nil {
714697
s.Process.User.Username = username
715698
} else {

oci/spec_opts_linux_test.go

+118-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package oci
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"io/ioutil"
2223
"os"
2324
"path/filepath"
@@ -31,6 +32,123 @@ import (
3132
"golang.org/x/sys/unix"
3233
)
3334

35+
// nolint:gosec
36+
func TestWithUserID(t *testing.T) {
37+
t.Parallel()
38+
39+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
40+
guest:x:405:100:guest:/dev/null:/sbin/nologin
41+
`
42+
td := t.TempDir()
43+
apply := fstest.Apply(
44+
fstest.CreateDir("/etc", 0777),
45+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
46+
)
47+
if err := apply.Apply(td); err != nil {
48+
t.Fatalf("failed to apply: %v", err)
49+
}
50+
c := containers.Container{ID: t.Name()}
51+
testCases := []struct {
52+
userID uint32
53+
expectedUID uint32
54+
expectedGID uint32
55+
}{
56+
{
57+
userID: 0,
58+
expectedUID: 0,
59+
expectedGID: 0,
60+
},
61+
{
62+
userID: 405,
63+
expectedUID: 405,
64+
expectedGID: 100,
65+
},
66+
{
67+
userID: 1000,
68+
expectedUID: 1000,
69+
expectedGID: 0,
70+
},
71+
}
72+
for _, testCase := range testCases {
73+
t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) {
74+
t.Parallel()
75+
s := Spec{
76+
Version: specs.Version,
77+
Root: &specs.Root{
78+
Path: td,
79+
},
80+
Linux: &specs.Linux{},
81+
}
82+
err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s)
83+
assert.NoError(t, err)
84+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
85+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
86+
})
87+
}
88+
}
89+
90+
// nolint:gosec
91+
func TestWithUsername(t *testing.T) {
92+
t.Parallel()
93+
94+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
95+
guest:x:405:100:guest:/dev/null:/sbin/nologin
96+
`
97+
td := t.TempDir()
98+
apply := fstest.Apply(
99+
fstest.CreateDir("/etc", 0777),
100+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
101+
)
102+
if err := apply.Apply(td); err != nil {
103+
t.Fatalf("failed to apply: %v", err)
104+
}
105+
c := containers.Container{ID: t.Name()}
106+
testCases := []struct {
107+
user string
108+
expectedUID uint32
109+
expectedGID uint32
110+
err string
111+
}{
112+
{
113+
user: "root",
114+
expectedUID: 0,
115+
expectedGID: 0,
116+
},
117+
{
118+
user: "guest",
119+
expectedUID: 405,
120+
expectedGID: 100,
121+
},
122+
{
123+
user: "1000",
124+
err: "no users found",
125+
},
126+
{
127+
user: "unknown",
128+
err: "no users found",
129+
},
130+
}
131+
for _, testCase := range testCases {
132+
t.Run(testCase.user, func(t *testing.T) {
133+
t.Parallel()
134+
s := Spec{
135+
Version: specs.Version,
136+
Root: &specs.Root{
137+
Path: td,
138+
},
139+
Linux: &specs.Linux{},
140+
}
141+
err := WithUsername(testCase.user)(context.Background(), nil, &c, &s)
142+
if err != nil {
143+
assert.EqualError(t, err, testCase.err)
144+
}
145+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
146+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
147+
})
148+
}
149+
150+
}
151+
34152
// nolint:gosec
35153
func TestWithAdditionalGIDs(t *testing.T) {
36154
t.Parallel()
@@ -55,7 +173,6 @@ sys:x:3:root,bin,adm
55173
c := containers.Container{ID: t.Name()}
56174

57175
testCases := []struct {
58-
name string
59176
user string
60177
expected []uint32
61178
}{

0 commit comments

Comments
 (0)