Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
861 changes: 747 additions & 114 deletions sdk/go/README.md

Large diffs are not rendered by default.

289 changes: 253 additions & 36 deletions sdk/go/dstack/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ package dstack
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/sha512"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"strings"
"time"
)

// Represents the response from a TLS key derivation request.
Expand All @@ -27,20 +32,83 @@ type GetTlsKeyResponse struct {
CertificateChain []string `json:"certificate_chain"`
}

// AsUint8Array converts the private key to bytes, optionally limiting the length
func (r *GetTlsKeyResponse) AsUint8Array(maxLength ...int) ([]byte, error) {
block, _ := pem.Decode([]byte(r.Key))
if block == nil {
return nil, fmt.Errorf("failed to decode pem private key")
}

key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

var keyBytes []byte
switch k := key.(type) {
case *ecdsa.PrivateKey:
keyBytes = k.D.FillBytes(make([]byte, (k.Curve.Params().N.BitLen()+7)/8))
case ed25519.PrivateKey:
keyBytes = k.Seed()
default:
return nil, fmt.Errorf("unsupported key type: %T", key)
}

if len(maxLength) > 0 && maxLength[0] > 0 && maxLength[0] < len(keyBytes) {
return keyBytes[:maxLength[0]], nil
}
return keyBytes, nil
}

// Represents the response from a key derivation request.
type GetKeyResponse struct {
Key string `json:"key"`
SignatureChain []string `json:"signature_chain"`
}

// DecodeKey returns the key as bytes
func (r *GetKeyResponse) DecodeKey() ([]byte, error) {
return hex.DecodeString(r.Key)
}

// DecodeSignatureChain returns the signature chain as bytes
func (r *GetKeyResponse) DecodeSignatureChain() ([][]byte, error) {
result := make([][]byte, len(r.SignatureChain))
for i, sig := range r.SignatureChain {
bytes, err := hex.DecodeString(sig)
if err != nil {
return nil, fmt.Errorf("failed to decode signature %d: %w", i, err)
}
result[i] = bytes
}
return result, nil
}

// Represents the response from a quote request.
type GetQuoteResponse struct {
Quote []byte `json:"quote"`
Quote string `json:"quote"`
EventLog string `json:"event_log"`
ReportData []byte `json:"report_data"`
ReportData string `json:"report_data"`
VmConfig string `json:"vm_config"`
}

// DecodeQuote returns the quote bytes
func (r *GetQuoteResponse) DecodeQuote() ([]byte, error) {
return hex.DecodeString(r.Quote)
}

// DecodeReportData returns the report data bytes
func (r *GetQuoteResponse) DecodeReportData() ([]byte, error) {
return hex.DecodeString(r.ReportData)
}

// DecodeEventLog returns the event log as structured data
func (r *GetQuoteResponse) DecodeEventLog() ([]EventLog, error) {
var events []EventLog
err := json.Unmarshal([]byte(r.EventLog), &events)
return events, err
}

// Represents the response from an attestation request.
type AttestResponse struct {
Attestation []byte
Expand All @@ -57,17 +125,20 @@ type EventLog struct {

// Represents the TCB information
type TcbInfo struct {
Mrtd string `json:"mrtd"`
Rtmr0 string `json:"rtmr0"`
Rtmr1 string `json:"rtmr1"`
Rtmr2 string `json:"rtmr2"`
Rtmr3 string `json:"rtmr3"`
// The hash of the OS image. This is empty if the OS image is not measured by KMS.
OsImageHash string `json:"os_image_hash,omitempty"`
ComposeHash string `json:"compose_hash"`
DeviceID string `json:"device_id"`
AppCompose string `json:"app_compose"`
EventLog []EventLog `json:"event_log"`
Mrtd string `json:"mrtd"`
Rtmr0 string `json:"rtmr0"`
Rtmr1 string `json:"rtmr1"`
Rtmr2 string `json:"rtmr2"`
Rtmr3 string `json:"rtmr3"`
AppCompose string `json:"app_compose"`
EventLog []EventLog `json:"event_log"`
// V0.3.x fields
RootfsHash string `json:"rootfs_hash,omitempty"`
// V0.5.x fields
MrAggregated string `json:"mr_aggregated,omitempty"`
OsImageHash string `json:"os_image_hash,omitempty"`
ComposeHash string `json:"compose_hash,omitempty"`
DeviceID string `json:"device_id,omitempty"`
}

// Represents the response from an info request
Expand All @@ -81,9 +152,11 @@ type InfoResponse struct {
MrAggregated string `json:"mr_aggregated,omitempty"`
KeyProviderInfo string `json:"key_provider_info"`
// Optional: empty if OS image is not measured by KMS
OsImageHash string `json:"os_image_hash,omitempty"`
ComposeHash string `json:"compose_hash"`
VmConfig string `json:"vm_config,omitempty"`
OsImageHash string `json:"os_image_hash,omitempty"`
ComposeHash string `json:"compose_hash"`
VmConfig string `json:"vm_config,omitempty"`
CloudVendor string `json:"cloud_vendor,omitempty"`
CloudProduct string `json:"cloud_product,omitempty"`
}

// DecodeTcbInfo decodes the TcbInfo string into a TcbInfo struct
Expand Down Expand Up @@ -269,6 +342,7 @@ func (c *DstackClient) sendRPCRequest(ctx context.Context, path string, payload
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "dstack-sdk-go/0.1.0")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -297,6 +371,9 @@ type tlsKeyOptions struct {
usageRaTls bool
usageServerAuth bool
usageClientAuth bool
notBefore *uint64
notAfter *uint64
withAppInfo *bool
}

// WithSubject sets the subject for the TLS key
Expand Down Expand Up @@ -334,6 +411,27 @@ func WithUsageClientAuth(usage bool) TlsKeyOption {
}
}

// WithNotBefore sets the not_before timestamp for the certificate
func WithNotBefore(t uint64) TlsKeyOption {
return func(opts *tlsKeyOptions) {
opts.notBefore = &t
}
}

// WithNotAfter sets the not_after timestamp for the certificate
func WithNotAfter(t uint64) TlsKeyOption {
return func(opts *tlsKeyOptions) {
opts.notAfter = &t
}
}

// WithAppInfo sets the with_app_info flag for the certificate
func WithAppInfo(enabled bool) TlsKeyOption {
return func(opts *tlsKeyOptions) {
opts.withAppInfo = &enabled
}
}

// Gets a TLS key from the dstack service with optional parameters.
func (c *DstackClient) GetTlsKey(
ctx context.Context,
Expand All @@ -356,6 +454,15 @@ func (c *DstackClient) GetTlsKey(
if len(opts.altNames) > 0 {
payload["alt_names"] = opts.altNames
}
if opts.notBefore != nil {
payload["not_before"] = *opts.notBefore
}
if opts.notAfter != nil {
payload["not_after"] = *opts.notAfter
}
if opts.withAppInfo != nil {
payload["with_app_info"] = *opts.withAppInfo
}

data, err := c.sendRPCRequest(ctx, "/GetTlsKey", payload)
if err != nil {
Expand Down Expand Up @@ -429,30 +536,12 @@ func (c *DstackClient) GetQuote(ctx context.Context, reportData []byte) (*GetQuo
return nil, err
}

var response struct {
Quote string `json:"quote"`
EventLog string `json:"event_log"`
ReportData string `json:"report_data"`
}
var response GetQuoteResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}

quote, err := hex.DecodeString(response.Quote)
if err != nil {
return nil, err
}

reportDataBytes, err := hex.DecodeString(response.ReportData)
if err != nil {
return nil, err
}

return &GetQuoteResponse{
Quote: quote,
EventLog: response.EventLog,
ReportData: reportDataBytes,
}, nil
return &response, nil
}

// Gets a versioned attestation from the dstack service.
Expand Down Expand Up @@ -600,14 +689,142 @@ func (c *DstackClient) Verify(ctx context.Context, algorithm string, data []byte
return &response, nil
}

// IsReachable checks if the service is reachable
func (c *DstackClient) IsReachable(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
_, err := c.Info(ctx)
return err == nil
}

// EmitEvent sends an event to be extended to RTMR3 on TDX platform.
// The event will be extended to RTMR3 with the provided name and payload.
//
// Requires dstack OS 0.5.0 or later.
func (c *DstackClient) EmitEvent(ctx context.Context, event string, payload []byte) error {
if event == "" {
return fmt.Errorf("event name cannot be empty")
}
_, err := c.sendRPCRequest(ctx, "/EmitEvent", map[string]interface{}{
"event": event,
"payload": hex.EncodeToString(payload),
})
return err
}

// Legacy methods for backward compatibility with warnings

// DeriveKey is deprecated. Use GetKey instead.
// Deprecated: Use GetKey instead.
func (c *DstackClient) DeriveKey(path string, subject string, altNames []string) (*GetTlsKeyResponse, error) {
return nil, fmt.Errorf("deriveKey is deprecated, please use GetKey instead")
}

// TdxQuote is deprecated. Use GetQuote instead.
// Deprecated: Use GetQuote instead.
func (c *DstackClient) TdxQuote(ctx context.Context, reportData []byte, hashAlgorithm string) (*GetQuoteResponse, error) {
c.logger.Warn("tdxQuote is deprecated, please use GetQuote instead")
if hashAlgorithm != "raw" {
return nil, fmt.Errorf("tdxQuote only supports raw hash algorithm")
}
return c.GetQuote(ctx, reportData)
}

// TappdClient is a deprecated wrapper around DstackClient for backward compatibility.
// Deprecated: Use DstackClient instead.
type TappdClient struct {
*DstackClient
}

// NewTappdClient creates a new deprecated TappdClient.
// Deprecated: Use NewDstackClient instead.
func NewTappdClient(opts ...DstackClientOption) *TappdClient {
// Create a modified option to use TAPPD_SIMULATOR_ENDPOINT
tappdOpts := make([]DstackClientOption, 0, len(opts)+1)

// Add default endpoint option that checks TAPPD_SIMULATOR_ENDPOINT
tappdOpts = append(tappdOpts, func(c *DstackClient) {
if c.endpoint == "" {
if simEndpoint, exists := os.LookupEnv("TAPPD_SIMULATOR_ENDPOINT"); exists {
c.logger.Warn("Using tappd endpoint", "endpoint", simEndpoint)
c.endpoint = simEndpoint
} else {
c.endpoint = "/var/run/tappd.sock"
}
}
})

// Add user-provided options
tappdOpts = append(tappdOpts, opts...)

client := NewDstackClient(tappdOpts...)
client.logger.Warn("TappdClient is deprecated, please use DstackClient instead")

return &TappdClient{
DstackClient: client,
}
}

// Override deprecated methods to use proper tappd RPC paths

// DeriveKey is deprecated. Use GetKey instead.
// Deprecated: Use GetKey instead.
func (tc *TappdClient) DeriveKey(ctx context.Context, path string, subject string, altNames []string) (*GetTlsKeyResponse, error) {
tc.logger.Warn("deriveKey is deprecated, please use GetKey instead")

if subject == "" {
subject = path
}

payload := map[string]interface{}{
"path": path,
"subject": subject,
}
if len(altNames) > 0 {
payload["alt_names"] = altNames
}

data, err := tc.sendRPCRequest(ctx, "/prpc/Tappd.DeriveKey", payload)
if err != nil {
return nil, err
}

var response GetTlsKeyResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}

// TdxQuote is deprecated. Use GetQuote instead.
// Deprecated: Use GetQuote instead.
func (tc *TappdClient) TdxQuote(ctx context.Context, reportData []byte, hashAlgorithm string) (*GetQuoteResponse, error) {
tc.logger.Warn("tdxQuote is deprecated, please use GetQuote instead")

if hashAlgorithm == "raw" {
if len(reportData) > 64 {
return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is raw")
}
if len(reportData) < 64 {
// Left-pad with zeros
padding := make([]byte, 64-len(reportData))
reportData = append(padding, reportData...)
}
}

payload := map[string]interface{}{
"report_data": hex.EncodeToString(reportData),
"hash_algorithm": hashAlgorithm,
}

data, err := tc.sendRPCRequest(ctx, "/prpc/Tappd.TdxQuote", payload)
if err != nil {
return nil, err
}

var response GetQuoteResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}
Loading