diff --git a/api/pod.go b/api/pod.go index 6847957..e267523 100644 --- a/api/pod.go +++ b/api/pod.go @@ -155,10 +155,34 @@ type CreatePodInput struct { Ports string `json:"ports"` SupportPublicIp bool `json:"supportPublicIp"` StartSSH bool `json:"startSsh"` + Spot bool `json:"spot"` TemplateId string `json:"templateId"` VolumeInGb int `json:"volumeInGb"` VolumeMountPath string `json:"volumeMountPath"` } + +type CreateSpotPodInput struct { + BidPerGpu float32 `json:"bidPerGpu"` + CloudType string `json:"cloudType"` + ContainerDiskInGb int `json:"containerDiskInGb"` + DataCenterId string `json:"dataCenterId"` + DockerArgs string `json:"dockerArgs"` + Env []*PodEnv `json:"env"` + GpuCount int `json:"gpuCount"` + GpuTypeId string `json:"gpuTypeId"` + ImageName string `json:"imageName"` + MinMemoryInGb int `json:"minMemoryInGb"` + MinVcpuCount int `json:"minVcpuCount"` + Name string `json:"name"` + NetworkVolumeId string `json:"networkVolumeId"` + Ports string `json:"ports"` + SupportPublicIp bool `json:"supportPublicIp"` + StartSSH bool `json:"startSsh"` + TemplateId string `json:"templateId"` + VolumeInGb int `json:"volumeInGb"` + VolumeMountPath string `json:"volumeMountPath"` +} + type PodEnv struct { Key string `json:"key"` Value string `json:"value"` @@ -219,6 +243,82 @@ func CreatePod(podInput *CreatePodInput) (pod map[string]interface{}, err error) return } +func CreateSpotPod(podInput *CreatePodInput, bidPerGpu float32) (pod map[string]interface{}, err error) { + if podInput.Name == "" { + names := strings.Split(podInput.ImageName, ":") + podInput.Name = names[0] + } + + input := Input{ + Query: ` + mutation podRentInterruptable($input: PodRentInterruptableInput!) { + podRentInterruptable(input: $input) { + id + costPerHr + desiredStatus + lastStatusChange + } + } + `, + Variables: map[string]interface{}{ + "input": map[string]interface{}{ + "bidPerGpu": bidPerGpu, + "cloudType": podInput.CloudType, + "containerDiskInGb": podInput.ContainerDiskInGb, + "dataCenterId": podInput.DataCenterId, + "dockerArgs": podInput.DockerArgs, + "env": podInput.Env, + "gpuCount": podInput.GpuCount, + "gpuTypeId": podInput.GpuTypeId, + "imageName": podInput.ImageName, + "minMemoryInGb": podInput.MinMemoryInGb, + "minVcpuCount": podInput.MinVcpuCount, + "name": podInput.Name, + "networkVolumeId": podInput.NetworkVolumeId, + "ports": podInput.Ports, + "startSsh": podInput.StartSSH, + "templateId": podInput.TemplateId, + "volumeInGb": podInput.VolumeInGb, + "volumeMountPath": podInput.VolumeMountPath, + }, + }, + } + res, err := Query(input) + if err != nil { + return + } + defer res.Body.Close() + rawData, err := io.ReadAll(res.Body) + if err != nil { + return + } + if res.StatusCode != 200 { + err = fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) + return + } + data := make(map[string]interface{}) + if err = json.Unmarshal(rawData, &data); err != nil { + return + } + gqlErrors, ok := data["errors"].([]interface{}) + if ok && len(gqlErrors) > 0 { + firstErr, _ := gqlErrors[0].(map[string]interface{}) + err = errors.New(firstErr["message"].(string)) + return + } + gqldata, ok := data["data"].(map[string]interface{}) + if !ok || gqldata == nil { + err = fmt.Errorf("data is nil: %s", string(rawData)) + return + } + pod, ok = gqldata["podRentInterruptable"].(map[string]interface{}) + if !ok || pod == nil { + err = fmt.Errorf("pod is nil: %s", string(rawData)) + return + } + return +} + func StopPod(id string) (podStop map[string]interface{}, err error) { input := Input{ Query: ` diff --git a/cmd/pod/createPod.go b/cmd/pod/createPod.go index 249e8c1..08dcd34 100644 --- a/cmd/pod/createPod.go +++ b/cmd/pod/createPod.go @@ -70,11 +70,28 @@ var CreatePodCmd = &cobra.Command{ } else { input.CloudType = "COMMUNITY" } - pod, err := api.CreatePod(input) + + var pod map[string]interface{} + var err error + + // Get the flag value and check if it was explicitly set + bidPerGpuFlag := cmd.Flags().Lookup("bidPerGpu") + if bidPerGpuFlag.Changed { + if bidPerGpu <= 0 { + cobra.CheckErr(fmt.Errorf("bidPerGpu must be greater than 0")) + } + pod, err = api.CreateSpotPod(input, bidPerGpu) + } else { + pod, err = api.CreatePod(input) + } cobra.CheckErr(err) if pod["desiredStatus"] == "RUNNING" { - fmt.Printf(`pod "%s" created for $%.3f / hr`, pod["id"], pod["costPerHr"]) + podType := "pod" + if bidPerGpu > 0 { + podType = "spot pod" + } + fmt.Printf(`%s "%s" created for $%.3f / hr`, podType, pod["id"], pod["costPerHr"]) fmt.Println() } else { cobra.CheckErr(fmt.Errorf(`pod "%s" start failed; status is %s`, args[0], pod["desiredStatus"])) @@ -102,6 +119,7 @@ func init() { CreatePodCmd.Flags().StringVar(&networkVolumeId, "networkVolumeId", "", "network volume id") CreatePodCmd.Flags().StringVar(&dataCenterId, "dataCenterId", "", "datacenter id to create in") CreatePodCmd.Flags().BoolVar(&startSSH, "startSSH", false, "enable SSH login") + CreatePodCmd.Flags().Float32Var(&bidPerGpu, "bidPerGpu", 0, "bid per gpu for spot price (if set, creates a spot pod)") CreatePodCmd.MarkFlagRequired("gpuType") //nolint CreatePodCmd.MarkFlagRequired("imageName") //nolint