4.5 KB
server.go
package digitalocean

import (
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/digitalocean/godo"
	"github.com/readysite/readysite/pkg/platform"
)

// CreateServer provisions a DigitalOcean droplet described by opts and waits
// for it to become active before returning it.
//
// opts.Region and opts.Size are looked up in the package's regions and sizes
// tables; a missing entry yields an error wrapping
// platform.ErrUnsupportedRegion or platform.ErrUnsupportedSize respectively.
// When opts.SSHKey is non-empty it is sent as a key fingerprint (the value
// returned by GetSSHKeyFingerprint).
func (b *backend) CreateServer(opts platform.ServerOptions) (*platform.Server, error) {
	// Map the platform-level region/size names onto DigitalOcean's.
	doRegion, found := regions[opts.Region]
	if !found {
		return nil, fmt.Errorf("%w: %s", platform.ErrUnsupportedRegion, opts.Region)
	}
	doSize, found := sizes[opts.Size]
	if !found {
		return nil, fmt.Errorf("%w: %s", platform.ErrUnsupportedSize, opts.Size)
	}

	req := &godo.DropletCreateRequest{
		Name:   opts.Name,
		Region: doRegion,
		Size:   doSize,
		Image:  godo.DropletCreateImage{Slug: opts.Image},
		Tags:   opts.Tags,
	}
	// Only attach an SSH key when the caller supplied a fingerprint.
	if opts.SSHKey != "" {
		req.SSHKeys = []godo.DropletCreateSSHKey{{Fingerprint: opts.SSHKey}}
	}

	created, _, err := b.client.Droplets.Create(b.ctx, req)
	if err != nil {
		return nil, fmt.Errorf("create droplet: %w", err)
	}

	// Creation is asynchronous: poll until the droplet is active so the
	// returned server carries its final status and address.
	ready, err := b.waitForDroplet(created.ID)
	if err != nil {
		return nil, err
	}
	return ready, nil
}

// GetServer looks up a droplet by its exact name and returns it as a
// platform.Server. It returns platform.ErrNotFound when none of the listed
// droplets carries that name.
func (b *backend) GetServer(name string) (*platform.Server, error) {
	matches, _, err := b.client.Droplets.ListByName(b.ctx, name, nil)
	if err != nil {
		return nil, fmt.Errorf("list droplets: %w", err)
	}

	// Re-check the name on each result and return the first exact match.
	for i := range matches {
		if candidate := &matches[i]; candidate.Name == name {
			return b.dropletToServer(candidate), nil
		}
	}

	return nil, platform.ErrNotFound
}

// DeleteServer destroys the droplet identified by id.
//
// id must be the decimal droplet ID (the form dropletToServer stores in
// platform.Server.ID). A malformed id yields an error without contacting the
// API. When the API reports that the droplet does not exist,
// platform.ErrNotFound is returned; any other failure is wrapped with
// "delete droplet".
func (b *backend) DeleteServer(id string) error {
	// DigitalOcean droplet IDs are integers. strconv.Atoi rejects trailing
	// garbage such as "12abc", which fmt.Sscanf("%d") would silently accept
	// as 12 and thereby target the wrong droplet.
	dropletID, err := strconv.Atoi(id)
	if err != nil {
		return fmt.Errorf("invalid droplet ID %q: %w", id, err)
	}

	resp, err := b.client.Droplets.Delete(b.ctx, dropletID)
	if err != nil {
		// Prefer the HTTP status over the error text, which depends on the
		// exact wording of the API message. The text match is kept as a
		// fallback for errors that carry no response.
		const statusNotFound = 404
		notFound := resp != nil && resp.Response != nil && resp.StatusCode == statusNotFound
		if notFound || strings.Contains(err.Error(), "not found") {
			return platform.ErrNotFound
		}
		return fmt.Errorf("delete droplet: %w", err)
	}
	return nil
}

// Helper methods

// waitForDroplet polls the droplet with the given ID every 5 seconds until
// its status is "active", then returns it converted to a platform.Server.
//
// It gives up after 5 minutes with platform.ErrTimeout; if the most recent
// poll failed, the returned error wraps platform.ErrTimeout and includes that
// failure so persistent API errors are diagnosable. If b.ctx is cancelled
// while waiting, the context's error is returned (wrapped) immediately.
func (b *backend) waitForDroplet(id int) (*platform.Server, error) {
	const (
		pollInterval = 5 * time.Second
		maxWait      = 5 * time.Minute
	)

	deadline := time.NewTimer(maxWait)
	defer deadline.Stop()
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	// lastErr remembers why the most recent poll failed (nil if it succeeded).
	var lastErr error
	for {
		select {
		case <-b.ctx.Done():
			// Stop waiting as soon as the caller's context is cancelled
			// instead of polling a dead context until the timeout.
			return nil, fmt.Errorf("wait for droplet %d: %w", id, b.ctx.Err())
		case <-deadline.C:
			if lastErr != nil {
				return nil, fmt.Errorf("%w: droplet %d: last error: %v", platform.ErrTimeout, id, lastErr)
			}
			return nil, platform.ErrTimeout
		case <-ticker.C:
			droplet, _, err := b.client.Droplets.Get(b.ctx, id)
			if err != nil {
				// Lookup failures are treated as transient: keep polling,
				// but remember the cause for the timeout error.
				lastErr = err
				continue
			}
			lastErr = nil
			if droplet.Status == "active" {
				return b.dropletToServer(droplet), nil
			}
		}
	}
}

// dropletToServer converts a godo droplet into the platform's Server
// representation.
//
// IP is the first IPv4 address whose network type is "public", or empty if
// the droplet has none. The nested Networks, Size and Region values are
// pointers and are checked for nil before use, so a partially populated
// droplet yields empty fields rather than a panic. When Size is absent the
// droplet's SizeSlug field is used instead.
func (b *backend) dropletToServer(d *godo.Droplet) *platform.Server {
	server := &platform.Server{
		ID:     strconv.Itoa(d.ID),
		Name:   d.Name,
		Size:   d.SizeSlug,
		Status: d.Status,
	}

	if d.Size != nil {
		server.Size = d.Size.Slug
	}
	if d.Region != nil {
		server.Region = d.Region.Slug
	}
	if d.Networks != nil {
		for _, network := range d.Networks.V4 {
			if network.Type == "public" {
				server.IP = network.IPAddress
				break
			}
		}
	}

	return server
}

// GetSSHKeyFingerprint returns the fingerprint of the account SSH key that
// matches publicKey, registering publicKey as a new key (named
// "readysite-<unix time>") when no match exists.
//
// Keys are compared by their first two whitespace-separated fields (key type
// and key data); any trailing comment is ignored. publicKey must contain at
// least those two fields, otherwise an error is returned without contacting
// the API.
func (b *backend) GetSSHKeyFingerprint(publicKey string) (string, error) {
	// Trim whitespace to handle trailing newlines from key files.
	publicKey = strings.TrimSpace(publicKey)
	identity := sshKeyIdentity(publicKey)
	if identity == "" {
		return "", fmt.Errorf("invalid SSH public key: expected \"<type> <data> [comment]\"")
	}

	fingerprint, found, err := b.findSSHKeyFingerprint(identity)
	if err != nil {
		return "", fmt.Errorf("list SSH keys: %w", err)
	}
	if found {
		return fingerprint, nil
	}

	key, _, err := b.client.Keys.Create(b.ctx, &godo.KeyCreateRequest{
		Name:      fmt.Sprintf("readysite-%d", time.Now().Unix()),
		PublicKey: publicKey,
	})
	if err != nil {
		// "already in use" means the key exists even though the lookup above
		// missed it (for example it was registered concurrently). Look once
		// more before giving up; the create error is reported if that fails.
		if strings.Contains(err.Error(), "already in use") {
			if fingerprint, found, lookupErr := b.findSSHKeyFingerprint(identity); lookupErr == nil && found {
				return fingerprint, nil
			}
		}
		return "", fmt.Errorf("create SSH key: %w", err)
	}

	return key.Fingerprint, nil
}

// findSSHKeyFingerprint walks every page of the account's SSH keys and
// reports the fingerprint of the first key whose identity equals identity.
func (b *backend) findSSHKeyFingerprint(identity string) (string, bool, error) {
	opts := &godo.ListOptions{PerPage: 200}
	for {
		keys, resp, err := b.client.Keys.List(b.ctx, opts)
		if err != nil {
			return "", false, err
		}
		for _, k := range keys {
			if sshKeyIdentity(k.PublicKey) == identity {
				return k.Fingerprint, true, nil
			}
		}

		// Stop when the response carries no pagination links or this was
		// the last page.
		if resp == nil || resp.Links == nil || resp.Links.IsLastPage() {
			return "", false, nil
		}
		page, err := resp.Links.CurrentPage()
		if err != nil {
			return "", false, err
		}
		opts.Page = page + 1
	}
}

// sshKeyIdentity returns "<type> <data>" for an authorized_keys-style public
// key, dropping the comment, or "" when the key has fewer than two fields.
func sshKeyIdentity(publicKey string) string {
	fields := strings.Fields(publicKey)
	if len(fields) < 2 {
		return ""
	}
	return fields[0] + " " + fields[1]
}
← Back