// Source: megrez/services/instanceController/patch.go (166 lines, 4.8 KiB, Go)

package instanceController
import (
"errors"
"megrez/models"
"megrez/services/database"
"strings"
)
// patchReqStruct groups the optional patch sections of a request, keyed by
// their JSON names. Every section is a pointer, so a section can be left nil
// when it is not part of the request.
type patchReqStruct struct {
	// CpuPatch carries the requested CPU change ("cpuPatch").
	CpuPatch *cpuPatchStruct `json:"cpuPatch"`
	// GpuPatch carries the requested GPU change ("gpuPatch").
	GpuPatch *gpuPatchStruct `json:"gpuPatch"`
	// MemoryPatch carries the requested memory change ("memoryPatch").
	MemoryPatch *MemoryPatchStruct `json:"memoryPatch"`
	// VolumePatch carries the requested volume bind change ("volumePatch").
	VolumePatch *volumePatchStruct `json:"volumePatch"`
}
// cpuPatchStruct is the CPU section of a patch request.
type cpuPatchStruct struct {
	// CpuCount is the requested CPU count, serialized as "cpuCount".
	CpuCount int `json:"cpuCount"`
}
// gpuPatchStruct is the GPU section of a patch request.
type gpuPatchStruct struct {
	// GpuCount is the requested GPU count, serialized as "gpuCount".
	GpuCount int `json:"gpuCount"`
}
// MemoryPatchStruct is the memory section of a patch request.
type MemoryPatchStruct struct {
	// Memory is the requested memory value as a string, serialized as
	// "memory". The expected format is not defined here.
	Memory string `json:"memory"`
}
// volumePatchStruct is the volume section of a patch request: the bind to
// replace and the bind that replaces it.
type volumePatchStruct struct {
	// OldBind is the existing bind, serialized as "oldBind".
	OldBind bindStruct `json:"oldBind"`
	// NewBind is the replacement bind, serialized as "newBind".
	NewBind bindStruct `json:"newBind"`
}
// Patch applies a new GPU count, volume size and CPU-only mode to instance
// and persists the outcome.
//
// Steps, in order:
//  1. The instance is saved with status InstanceStatusModifying.
//  2. If gpuCount, volumeSize and cpuOnly all equal the instance's current
//     values, the status is reset to InstanceStatusStopped and a "no change"
//     error is returned.
//  3. The instance's server record is loaded by instance.ServerID.
//  4. If volumeSize differs, patchVolume is called and instance.VolumeName /
//     instance.VolumeSize are updated to the returned volume.
//  5. patchCpuOnly (cpuOnly == true) or patchGpu (cpuOnly == false) is called
//     with both the new and the old volume names.
//  6. The root password is re-applied, the Jupyter and code-server passwords
//     are re-applied in the background, the forwarded ports are re-read and
//     the instance is saved with status InstanceStatusRunning.
//
// Every failure is logged and returned as-is. Error paths do not write the
// status again, so after a failure past step 2 the stored status remains
// InstanceStatusModifying.
//
// The old volume is removed only when Patch succeeds: on failure the stored
// instance row still names the old volume, so it must not be deleted.
func Patch(instance *models.Instances, gpuCount, volumeSize int, cpuOnly bool) (err error) {
	l.SetFunction("Patch")

	instance.Status = models.InstanceStatusModifying
	result := database.DB.Save(&instance)
	if result.Error != nil {
		l.Error("save instance error: %v", result.Error)
		return result.Error
	}

	// Nothing to do: undo the "modifying" marker and report it to the caller.
	if gpuCount == instance.GpuCount && volumeSize == instance.VolumeSize && cpuOnly == instance.CpuOnly {
		instance.Status = models.InstanceStatusStopped
		result = database.DB.Save(&instance)
		if result.Error != nil {
			l.Error("save instance error: %v", result.Error)
			return result.Error
		}
		return errors.New("no change")
	}

	server := models.Servers{
		ID: instance.ServerID,
	}
	result = database.DB.First(&server)
	if result.Error != nil {
		l.Error("query server error: %v", result.Error)
		return result.Error
	}

	oldVolumeName := instance.VolumeName
	if volumeSize != instance.VolumeSize {
		// patchVolume is given the part of the volume name before the first "-".
		volumePrefix := strings.Split(instance.VolumeName, "-")[0]
		// Distinct error names keep the named result err unshadowed, so the
		// deferred closure below observes the function's real outcome.
		newVolumeName, patchErr := patchVolume(server.IP, server.Port, server.Apikey, volumePrefix, volumeSize)
		if patchErr != nil {
			l.Error("patch volume error: %v", patchErr)
			return patchErr
		}
		instance.VolumeName = newVolumeName
		instance.VolumeSize = volumeSize

		defer func() {
			// Keep the old volume when anything later failed: the instance
			// row was not saved with the new volume name in that case.
			if err != nil {
				return
			}
			if delErr := deleteVolume(server.IP, server.Port, server.Apikey, oldVolumeName, true); delErr != nil {
				l.Error("delete volume error: %v", delErr)
			}
		}()
	}

	if cpuOnly {
		err = patchCpuOnly(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.VolumeName, oldVolumeName)
		if err != nil {
			l.Error("patch cpu only error: %v", err)
			return err
		}
	} else {
		err = patchGpu(server.IP, server.Port, server.Apikey,
			instance.ContainerName,
			server.CpuCountPerGpu, server.MemoryPerGpu,
			instance.VolumeName, oldVolumeName,
			gpuCount, instance.GpuCount,
		)
		if err != nil {
			l.Error("patch gpu error: %v", err)
			return err
		}
	}

	err = SetRootPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd)
	if err != nil {
		l.Error("set root password error: %v", err)
		return err
	}

	// Best effort: these run in the background and their results are not
	// inspected, so a failure here does not fail the patch.
	go func() {
		SetJupterPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd)
		SetCodeServerPassword(server.IP, server.Port, server.Apikey, instance.ContainerName, instance.SshPasswd)
	}()

	portBindings, err := GetPortForward(server.IP, server.Port, server.Apikey, instance.ContainerName)
	if err != nil {
		l.Error("get port forward error: %v", err)
		return err
	}

	// Refresh every published address from the current port bindings. The
	// code-server address is refreshed in CPU-only mode too, since its
	// password is re-applied in both modes.
	instance.SshAddress = server.IP + ":" + portBindings["22"]
	instance.TensorBoardAddress = server.IP + ":" + portBindings["6007"]
	instance.JupyterAddress = server.IP + ":" + portBindings["8888"]
	instance.GrafanaAddress = server.IP + ":" + portBindings["3000"]
	instance.CodeServerAddress = server.IP + ":" + portBindings["8080"]

	if cpuOnly {
		instance.CpuOnly = true
		instance.GpuCount = 0
	} else {
		instance.CpuOnly = false
		instance.GpuCount = gpuCount
	}
	instance.Status = models.InstanceStatusRunning

	result = database.DB.Save(&instance)
	if result.Error != nil {
		l.Error("save instance error: %v", result.Error)
		return result.Error
	}
	return nil
}