From 9daa535c9cff0a3862a60b02c5ec931452651bb5 Mon Sep 17 00:00:00 2001 From: Harry-zklcdc Date: Mon, 13 Jan 2025 19:06:51 +0800 Subject: [PATCH] =?UTF-8?q?[Perf]=20=F0=9F=9A=80=20Set=20Jupter=20Token=20?= =?UTF-8?q?for=20Safety=20Improve=20#7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/instanceController/create.go | 2 + services/instanceController/execute.go | 85 ++++++++++++++++++++++++++ services/instanceController/patch.go | 4 ++ services/instanceController/restart.go | 2 + 4 files changed, 93 insertions(+) diff --git a/services/instanceController/create.go b/services/instanceController/create.go index 131c9eb..aed974d 100644 --- a/services/instanceController/create.go +++ b/services/instanceController/create.go @@ -47,6 +47,8 @@ func Create(instance *models.Instances) (containerName, volumeName string, err e return "", "", err } + go SetJupterPassword(server.IP, server.Port, server.Apikey, containerName, instance.SshPasswd) + portBindings, err := GetPortForward(server.IP, server.Port, server.Apikey, containerName) if err != nil { deleteInstance(server.IP, server.Port, server.Apikey, containerName) diff --git a/services/instanceController/execute.go b/services/instanceController/execute.go index 20a7dd9..7a64962 100644 --- a/services/instanceController/execute.go +++ b/services/instanceController/execute.go @@ -56,3 +56,88 @@ func SetRootPassword(ip string, port int, apikey string, return nil } + +func SetJupterPassword(ip string, port int, apikey string, + containerName, password string) (err error) { + l.SetFunction("SetJupterPassword") + + // Set Jupyter Password + data := executeReq{ + Cmd: []string{ + "sed", + "-i", + "/^c.ServerApp.token = ''/c\\c.ServerApp.token = '" + password + "'", + "/root/.jupyter/jupyter_notebook_config.py", + }, + } + + reqBytes, err := json.Marshal(data) + if err != nil { + l.Error("marshal request data error: %v", err) + return + } + + c := request.NewRequest().Post(). + SetUrl("http://" + ip + ":" + strconv.Itoa(port) + apiPrefix + instancePrefix + "/" + containerName + instanceExecute). + SetAuthorization("Bearer " + apikey). + SetUserAgent("megrez"). + SetBody(bytes.NewBuffer(reqBytes)) + c.Do() + + if c.GetStatusCode() != 200 { + l.Error("set jupter password error: %d", c.GetStatusCode()) + return errors.New("set jupter password request error") + } + + var res resStruct + err = json.Unmarshal(c.GetBody(), &res) + if err != nil { + l.Error("unmarshal response data error: %v", err) + return err + } + + if res.Code != 200 { + l.Error("set jupter password code: %d, error: %s", res.Code, res.Msg) + return errors.New(res.Msg) + } + + // Restart Jupyter + data = executeReq{ + Cmd: []string{ + "service", + "jupyter", + "restart", + }, + } + + reqBytes, err = json.Marshal(data) + if err != nil { + l.Error("marshal request data error: %v", err) + return + } + + c = request.NewRequest().Post(). + SetUrl("http://" + ip + ":" + strconv.Itoa(port) + apiPrefix + instancePrefix + "/" + containerName + instanceExecute). + SetAuthorization("Bearer " + apikey). + SetUserAgent("megrez"). + SetBody(bytes.NewBuffer(reqBytes)) + c.Do() + + if c.GetStatusCode() != 200 { + l.Error("restart jupter error: %d", c.GetStatusCode()) + return errors.New("restart jupter request error") + } + + err = json.Unmarshal(c.GetBody(), &res) + if err != nil { + l.Error("unmarshal response data error: %v", err) + return err + } + + if res.Code != 200 { + l.Error("restart jupter code: %d, error: %s", res.Code, res.Msg) + return errors.New(res.Msg) + } + + return nil +} diff --git a/services/instanceController/patch.go b/services/instanceController/patch.go index b1df6ae..e3861ea 100644 --- a/services/instanceController/patch.go +++ b/services/instanceController/patch.go @@ -95,6 +95,8 @@ func Patch(instance *models.Instances, gpuCount, volumeSize int, cpuOnly bool) ( return err } + go SetJupterPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd) + portBindings, err := GetPortForward(server.IP, server.Port, server.Apikey, instance.ContainerName) if err != nil { l.Error("get port forward error: %v", err) @@ -138,6 +140,8 @@ func Patch(instance *models.Instances, gpuCount, volumeSize int, cpuOnly bool) ( return err } + go SetJupterPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd) + portBindings, err := GetPortForward(server.IP, server.Port, server.Apikey, instance.ContainerName) if err != nil { l.Error("get port forward error: %v", err) diff --git a/services/instanceController/restart.go b/services/instanceController/restart.go index f6b29c1..9f729c9 100644 --- a/services/instanceController/restart.go +++ b/services/instanceController/restart.go @@ -48,6 +48,8 @@ func Restart(instance *models.Instances) (err error) { return err } + go SetJupterPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd) + portBindings, err := GetPortForward(server.IP, server.Port, server.Apikey, instance.ContainerName) if err != nil { l.Error("get port forward error: %v", err)