Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions api/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,34 @@ type CreatePodInput struct {
Ports string `json:"ports"`
SupportPublicIp bool `json:"supportPublicIp"`
StartSSH bool `json:"startSsh"`
Spot bool `json:"spot"`
TemplateId string `json:"templateId"`
VolumeInGb int `json:"volumeInGb"`
VolumeMountPath string `json:"volumeMountPath"`
}

type CreateSpotPodInput struct {
BidPerGpu float32 `json:"bidPerGpu"`
CloudType string `json:"cloudType"`
ContainerDiskInGb int `json:"containerDiskInGb"`
DataCenterId string `json:"dataCenterId"`
DockerArgs string `json:"dockerArgs"`
Env []*PodEnv `json:"env"`
GpuCount int `json:"gpuCount"`
GpuTypeId string `json:"gpuTypeId"`
ImageName string `json:"imageName"`
MinMemoryInGb int `json:"minMemoryInGb"`
MinVcpuCount int `json:"minVcpuCount"`
Name string `json:"name"`
NetworkVolumeId string `json:"networkVolumeId"`
Ports string `json:"ports"`
SupportPublicIp bool `json:"supportPublicIp"`
StartSSH bool `json:"startSsh"`
TemplateId string `json:"templateId"`
VolumeInGb int `json:"volumeInGb"`
VolumeMountPath string `json:"volumeMountPath"`
}

type PodEnv struct {
Key string `json:"key"`
Value string `json:"value"`
Expand Down Expand Up @@ -219,6 +243,82 @@ func CreatePod(podInput *CreatePodInput) (pod map[string]interface{}, err error)
return
}

func CreateSpotPod(podInput *CreatePodInput, bidPerGpu float32) (pod map[string]interface{}, err error) {
if podInput.Name == "" {
names := strings.Split(podInput.ImageName, ":")
podInput.Name = names[0]
}

input := Input{
Query: `
mutation podRentInterruptable($input: PodRentInterruptableInput!) {
podRentInterruptable(input: $input) {
id
costPerHr
desiredStatus
lastStatusChange
}
}
`,
Variables: map[string]interface{}{
"input": map[string]interface{}{
"bidPerGpu": bidPerGpu,
"cloudType": podInput.CloudType,
"containerDiskInGb": podInput.ContainerDiskInGb,
"dataCenterId": podInput.DataCenterId,
"dockerArgs": podInput.DockerArgs,
"env": podInput.Env,
"gpuCount": podInput.GpuCount,
"gpuTypeId": podInput.GpuTypeId,
"imageName": podInput.ImageName,
"minMemoryInGb": podInput.MinMemoryInGb,
"minVcpuCount": podInput.MinVcpuCount,
"name": podInput.Name,
"networkVolumeId": podInput.NetworkVolumeId,
"ports": podInput.Ports,
"startSsh": podInput.StartSSH,
"templateId": podInput.TemplateId,
"volumeInGb": podInput.VolumeInGb,
"volumeMountPath": podInput.VolumeMountPath,
},
},
}
res, err := Query(input)
if err != nil {
return
}
defer res.Body.Close()
rawData, err := io.ReadAll(res.Body)
if err != nil {
return
}
if res.StatusCode != 200 {
err = fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData))
return
}
data := make(map[string]interface{})
if err = json.Unmarshal(rawData, &data); err != nil {
return
}
gqlErrors, ok := data["errors"].([]interface{})
if ok && len(gqlErrors) > 0 {
firstErr, _ := gqlErrors[0].(map[string]interface{})
err = errors.New(firstErr["message"].(string))
return
}
gqldata, ok := data["data"].(map[string]interface{})
if !ok || gqldata == nil {
err = fmt.Errorf("data is nil: %s", string(rawData))
return
}
pod, ok = gqldata["podRentInterruptable"].(map[string]interface{})
if !ok || pod == nil {
err = fmt.Errorf("pod is nil: %s", string(rawData))
return
}
return
}

func StopPod(id string) (podStop map[string]interface{}, err error) {
input := Input{
Query: `
Expand Down
22 changes: 20 additions & 2 deletions cmd/pod/createPod.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,28 @@ var CreatePodCmd = &cobra.Command{
} else {
input.CloudType = "COMMUNITY"
}
pod, err := api.CreatePod(input)

var pod map[string]interface{}
var err error

// Get the flag value and check if it was explicitly set
bidPerGpuFlag := cmd.Flags().Lookup("bidPerGpu")
if bidPerGpuFlag.Changed {
if bidPerGpu <= 0 {
cobra.CheckErr(fmt.Errorf("bidPerGpu must be greater than 0"))
}
pod, err = api.CreateSpotPod(input, bidPerGpu)
} else {
pod, err = api.CreatePod(input)
}
cobra.CheckErr(err)

if pod["desiredStatus"] == "RUNNING" {
fmt.Printf(`pod "%s" created for $%.3f / hr`, pod["id"], pod["costPerHr"])
podType := "pod"
if bidPerGpu > 0 {
podType = "spot pod"
}
fmt.Printf(`%s "%s" created for $%.3f / hr`, podType, pod["id"], pod["costPerHr"])
fmt.Println()
} else {
cobra.CheckErr(fmt.Errorf(`pod "%s" start failed; status is %s`, args[0], pod["desiredStatus"]))
Expand Down Expand Up @@ -102,6 +119,7 @@ func init() {
CreatePodCmd.Flags().StringVar(&networkVolumeId, "networkVolumeId", "", "network volume id")
CreatePodCmd.Flags().StringVar(&dataCenterId, "dataCenterId", "", "datacenter id to create in")
CreatePodCmd.Flags().BoolVar(&startSSH, "startSSH", false, "enable SSH login")
CreatePodCmd.Flags().Float32Var(&bidPerGpu, "bidPerGpu", 0, "bid per gpu for spot price (if set, creates a spot pod)")

CreatePodCmd.MarkFlagRequired("gpuType") //nolint
CreatePodCmd.MarkFlagRequired("imageName") //nolint
Expand Down