Skip to content

Commit

Permalink
Merge pull request #2429 from helinwang/master_client
Browse files Browse the repository at this point in the history
implement master server client, remove unnecessary dummy variable
  • Loading branch information
helinwang authored Jun 14, 2017
2 parents 0e2acb8 + 13867a0 commit 5f5e128
Show file tree
Hide file tree
Showing 7 changed files with 449 additions and 142 deletions.
52 changes: 2 additions & 50 deletions go/cmd/master/master.go
Original file line number Diff line number Diff line change
@@ -1,80 +1,32 @@
package main

import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/namsral/flag"

"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)

func main() {
port := flag.Int("port", 8080, "port of the master server.")
dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")

faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse()

if *dataset == "" {
panic("no dataset specified.")
}

if *faultTolerance {
panic("fault tolernance not implemented.")
}

var chunks []master.Chunk
var paths []string
ss := strings.Split(*dataset, ",")
fmt.Println(ss)
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil {
panic(err)
}
paths = append(paths, match...)
}

if len(paths) == 0 {
panic("no valid datset specified.")
}

idx := 0
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
panic(err)
}

index, err := recordio.LoadIndex(f)
if err != nil {
panic(err)
}
f.Close()

count := index.NumChunks()
for i := 0; i < count; i++ {
chunk := master.Chunk{
Idx: idx,
Path: path,
Index: *index.ChunkIndex(i),
}
chunks = append(chunks, chunk)
}
}

s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
err := rpc.Register(s)
if err != nil {
panic(err)
Expand Down
21 changes: 21 additions & 0 deletions go/pserver/internal/connection/conn.go → go/connection/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connection

import (
"errors"
"log"
"net/rpc"
"sync"
)
Expand All @@ -21,6 +22,18 @@ func New() *Conn {
return c
}

// Close shuts down the underlying RPC client, if one has been set.
//
// It returns nil when no client is set; otherwise it returns whatever
// the client's own Close reports. The connection mutex is held for the
// duration of the call.
func (c *Conn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if client := c.client; client != nil {
		return client.Close()
	}
	return nil
}

// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
Expand Down Expand Up @@ -50,12 +63,20 @@ func (c *Conn) Connect(addr string) error {
c.waitConn = nil
}
} else {
err := client.Close()
if err != nil {
log.Println(err)
}

return errors.New("client already set from a concurrent goroutine")
}

return nil
}

// TODO(helin): refactor Call to be able to perform given retry
// policy.

// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
Expand Down
82 changes: 82 additions & 0 deletions go/master/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package master

import (
"log"
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
)

// Addresser provides the address of the master server.
//
// Client.monitorMaster polls Address periodically and reacts when the
// returned value differs from the one it last acted on.
type Addresser interface {
// Address returns the current address of the master server. When the
// value changes to the empty string, the client's monitor closes its
// connection instead of dialing.
Address() string
}

// Client is the client of the master server.
//
// All RPCs issued by its methods go through conn.
type Client struct {
// conn is the connection used for every RPC; monitorMaster connects
// and closes it as the master address changes.
conn *connection.Conn
}

// NewClient creates a new Client.
//
// It also starts a background goroutine (monitorMaster) that follows
// the address reported by addr and keeps the client's connection
// pointed at it.
func NewClient(addr Addresser) *Client {
	client := &Client{conn: connection.New()}
	go client.monitorMaster(addr)
	return client
}

// monitorMaster keeps c.conn in sync with the master address reported
// by addr. It checks once immediately and then every 10 seconds, and
// never returns.
func (c *Client) monitorMaster(addr Addresser) {
	// known is the last address that was acted on successfully.
	known := ""

	poll := func() {
		// Get the latest address of the master server and react only
		// when it differs from the one we already know.
		latest := addr.Address()

		switch {
		case latest == known:
			// Address unchanged, nothing to do.
		case latest == "":
			// The master address went away: drop the connection.
			if err := c.conn.Close(); err != nil {
				log.Println(err)
			}
		default:
			if err := c.conn.Connect(latest); err != nil {
				log.Println(err)

				// Connecting to the new address failed. Keep the
				// last known address so that the next poll sees a
				// change again and retries.
				latest = known
			}
		}

		known = latest
	}

	poll()
	ticker := time.NewTicker(10 * time.Second)
	for range ticker.C {
		poll()
	}
}

// SetDataset sets the dataset for the master server to dispatch.
//
// SetDataset can be called multiple times from different nodes, but
// only the first call will be honored.
//
// globPaths is forwarded unchanged as the argument of the
// "Service.SetDataset" RPC; no reply value is requested.
func (c *Client) SetDataset(globPaths []string) error {
	err := c.conn.Call("Service.SetDataset", globPaths, nil)
	return err
}

// GetTask gets a new task from the master server.
//
// It issues the "Service.GetTask" RPC and returns the decoded task
// together with the error reported by the call. The task value is
// returned as-is even when the error is non-nil.
func (c *Client) GetTask() (Task, error) {
	var task Task
	// The RPC argument is the constant 0; presumably a placeholder the
	// service ignores (confirm against Service.GetTask).
	callErr := c.conn.Call("Service.GetTask", 0, &task)
	return task, callErr
}

// TaskFinished tells the master server a task is finished.
//
// taskID is sent as the argument of the "Service.TaskFinished" RPC; no
// reply value is requested.
func (c *Client) TaskFinished(taskID int) error {
	err := c.conn.Call("Service.TaskFinished", taskID, nil)
	return err
}
120 changes: 120 additions & 0 deletions go/master/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package master_test

import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"

log "github.com/sirupsen/logrus"

"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)

const (
// totalTask is the number of tasks the test fetches in one pass.
totalTask = 20
// chunkPerTask is passed to master.NewService; the test writes
// totalTask*chunkPerTask chunks into its dataset file.
chunkPerTask = 10
)

// port is the TCP port the test master server listens on; it is set in
// init from the listener's address.
var port int

// init starts an in-process master RPC server on an OS-assigned TCP
// port and records that port in the package-level variable port.
func init() {
	log.SetLevel(log.ErrorLevel)

	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}

	// The port is whatever follows the last ':' in the listener address.
	listenAddr := l.Addr().String()
	portStr := listenAddr[strings.LastIndex(listenAddr, ":")+1:]
	p, err := strconv.Atoi(portStr)
	if err != nil {
		panic(err)
	}
	port = p

	// Serve the master service over HTTP RPC in the background for the
	// lifetime of the test binary.
	serve := func(ln net.Listener) {
		svc := master.NewService(chunkPerTask, time.Second, 1)
		server := rpc.NewServer()
		if err := server.Register(svc); err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, server)
		if err := http.Serve(ln, mux); err != nil {
			panic(err)
		}
	}
	go serve(l)
}

// addresser is a fixed-address implementation of an Address() string
// provider: the string value itself is the address.
type addresser string

// Address returns the address held by a.
func (a addresser) Address() string {
return string(a)
}

// TestClientFull drives a master.Client through ten full passes over a
// dataset served by the in-process master started in init.
//
// Each pass fetches totalTask tasks, verifies that one more GetTask
// fails, finishes one task and checks that a task can be fetched again,
// then finishes every outstanding task.
func TestClientFull(t *testing.T) {
	const p = "/tmp/master_client_test_0"
	f, err := os.Create(p)
	if err != nil {
		t.Fatal(err)
	}
	// Remove the fixture when the test ends so repeated runs do not
	// leave files behind in /tmp.
	defer func() {
		if rmErr := os.Remove(p); rmErr != nil {
			t.Log(rmErr)
		}
	}()

	for i := 0; i < totalTask*chunkPerTask; i++ {
		w := recordio.NewWriter(f, -1, -1)
		// NOTE(review): the results of Write and Close are not checked
		// here; the recordio API is not visible from this file.
		w.Write(nil)
		// call Close to force RecordIO writing a chunk.
		w.Close()
	}
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
	// Fail right away if the dataset cannot be registered, instead of
	// surfacing the problem later as an unrelated GetTask error.
	if err := c.SetDataset([]string{p}); err != nil {
		t.Fatal(err)
	}

	checkOnePass := func(pass int) {
		var tasks []master.Task
		for i := 0; i < totalTask; i++ {
			task, err := c.GetTask()
			if err != nil {
				t.Fatal(i, err)
			}
			tasks = append(tasks, task)
		}

		// Every task of the pass has been handed out, so asking for
		// another one must fail.
		if _, err := c.GetTask(); err == nil {
			t.Fatal(pass, "should get error.")
		}

		// Finish one task, then check that a task can be fetched again.
		if err := c.TaskFinished(tasks[0].ID); err != nil {
			t.Fatal(err)
		}
		tasks = tasks[1:]
		task, err := c.GetTask()
		if err != nil {
			t.Fatal(err)
		}
		tasks = append(tasks, task)

		for _, task := range tasks {
			if err := c.TaskFinished(task.ID); err != nil {
				t.Fatal(pass, err)
			}
		}
	}

	for pass := 0; pass < 10; pass++ {
		checkOnePass(pass)
	}
}
Loading

0 comments on commit 5f5e128

Please sign in to comment.