diff --git a/conductor/tracks/rotate_client_secret_20260306/index.md b/conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/index.md similarity index 100% rename from conductor/tracks/rotate_client_secret_20260306/index.md rename to conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/index.md diff --git a/conductor/tracks/rotate_client_secret_20260306/metadata.json b/conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/metadata.json similarity index 100% rename from conductor/tracks/rotate_client_secret_20260306/metadata.json rename to conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/metadata.json diff --git a/conductor/tracks/rotate_client_secret_20260306/plan.md b/conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/plan.md similarity index 100% rename from conductor/tracks/rotate_client_secret_20260306/plan.md rename to conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/plan.md diff --git a/conductor/tracks/rotate_client_secret_20260306/spec.md b/conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/spec.md similarity index 100% rename from conductor/tracks/rotate_client_secret_20260306/spec.md rename to conductor/archive/rotate_client_secret_20260306/rotate_client_secret_20260306/spec.md diff --git a/conductor/archive/transit_key_retrieval_20260307/index.md b/conductor/archive/transit_key_retrieval_20260307/index.md new file mode 100644 index 0000000..5eb6421 --- /dev/null +++ b/conductor/archive/transit_key_retrieval_20260307/index.md @@ -0,0 +1,5 @@ +# Track transit_key_retrieval_20260307 Context + +- [Specification](./spec.md) +- [Implementation Plan](./plan.md) +- [Metadata](./metadata.json) diff --git a/conductor/archive/transit_key_retrieval_20260307/metadata.json b/conductor/archive/transit_key_retrieval_20260307/metadata.json new file mode 100644 index 0000000..95030bd --- /dev/null +++ b/conductor/archive/transit_key_retrieval_20260307/metadata.json @@ -0,0 +1,8 @@ +{ + "track_id": "transit_key_retrieval_20260307", + "type": "feature", + "status": "new", + "created_at": "2026-03-07T15:34:08Z", + "updated_at": "2026-03-07T15:34:08Z", + "description": "Add individual key retrieval for transit module" +} diff --git a/conductor/archive/transit_key_retrieval_20260307/plan.md b/conductor/archive/transit_key_retrieval_20260307/plan.md new file mode 100644 index 0000000..4fce37e --- /dev/null +++ b/conductor/archive/transit_key_retrieval_20260307/plan.md @@ -0,0 +1,25 @@ +# Implementation Plan: Transit Key Retrieval API +## Phase 1: Repository Layer +- [x] Task: Define `GetTransitKey` in `internal/transit/domain/repository.go` and repository interface. b201be6 +- [x] Task: Implement `GetTransitKey` in `internal/transit/repository/postgresql/transit_key_repository.go`. 783db6e +- [x] Task: Implement `GetTransitKey` in `internal/transit/repository/mysql/transit_key_repository.go`. 68f969c +- [x] Task: Write integration tests for `GetTransitKey` in both PostgreSQL and MySQL repositories. ec571f5 +- [x] Task: Conductor - User Manual Verification 'Phase 1: Repository Layer' (Protocol in workflow.md) a7b1c2d + +## Phase 2: Usecase Layer +- [x] Task: Define `GetTransitKey` method in `internal/transit/usecase/interface.go`. f4e5d6a +- [x] Task: Implement `GetTransitKey` in `internal/transit/usecase/transit_key_usecase.go`. 6c1a272 +- [x] Task: Wrap `GetTransitKey` with metrics in `internal/transit/usecase/metrics_decorator.go`. 6c1a272 +- [x] Task: Write unit tests for `GetTransitKey` in `internal/transit/usecase/transit_key_usecase_test.go`. 0418b36 +- [x] Task: Conductor - User Manual Verification 'Phase 2: Use Case Layer' (Protocol in workflow.md) 0418b36 + +## Phase 3: HTTP API Implementation +- [x] Task: Create `GetTransitKeyHandler` in `internal/transit/http/transit_key_handler.go`. 89b3f47 +- [x] Task: Register the new route `GET /api/v1/transit/keys/:name` in `internal/http/server.go`. 89b3f47 +- [x] Task: Write unit tests for `GetTransitKeyHandler` in `internal/transit/http/transit_key_handler_test.go`. 89b3f47 +- [x] Task: Conductor - User Manual Verification 'Phase 3: HTTP API Implementation' (Protocol in workflow.md) 89b3f47 + +## Phase 4: Documentation +- [x] Task: Update `docs/engines/transit.md` to document the new key retrieval capability. e14b2ad +- [x] Task: Update `docs/openapi.yaml` to include the `GET /api/v1/transit/keys/:name` endpoint. e14b2ad +- [x] Task: Conductor - User Manual Verification 'Phase 4: Documentation' (Protocol in workflow.md) e14b2ad diff --git a/conductor/archive/transit_key_retrieval_20260307/spec.md b/conductor/archive/transit_key_retrieval_20260307/spec.md new file mode 100644 index 0000000..fd01308 --- /dev/null +++ b/conductor/archive/transit_key_retrieval_20260307/spec.md @@ -0,0 +1,36 @@ +# Specification: Transit Key Retrieval API + +## Overview +Add a new API endpoint to the transit module to allow clients to retrieve metadata for individual transit keys. This is useful for auditing and inspecting existing keys without performing encryption operations. + +## Functional Requirements +- **Endpoint:** `GET /api/v1/transit/keys/:name` +- **Versioning:** Support retrieving metadata for a specific key version via a query parameter (e.g., `?version=2`). If omitted, return metadata for the latest version. +- **Capability:** Require the `read` capability for the requested path. +- **Response:** + - `name`: String + - `type`: String (e.g., aes256-gcm96, chacha20-poly1305) + - `version`: Integer + - `created_at`: RFC3339 Timestamp + - `updated_at`: RFC3339 Timestamp + +## Non-Functional Requirements +- **Security:** Ensure that the API never returns sensitive key material. +- **Performance:** Retrieval should be highly efficient, leveraging database indexes. + +## Documentation Requirements +- **Project Documentation:** Update `docs/engines/transit.md` to document the new key retrieval capability. +- **API Reference:** Update `docs/openapi.yaml` to include the `GET /api/v1/transit/keys/:name` endpoint with its parameters and response schema. + +## Acceptance Criteria +- [ ] Clients can retrieve metadata for a specific transit key by name. +- [ ] The API correctly handles the `version` query parameter. +- [ ] Requests without the `read` capability are rejected with `403 Forbidden`. +- [ ] Requests for non-existent keys return `404 Not Found`. +- [ ] API documentation (OpenAPI) is updated to include the new endpoint. +- [ ] Transit engine documentation in `docs/engines/transit.md` is updated. + +## Out of Scope +- CLI command implementation. +- Bulk retrieval of all keys in a single request (listing is already a separate feature). +- Modification of key properties via this endpoint. diff --git a/docs/engines/transit.md b/docs/engines/transit.md index b9369e6..45ff493 100644 --- a/docs/engines/transit.md +++ b/docs/engines/transit.md @@ -98,6 +98,32 @@ Example decrypt response (`200 OK`): ### List and Delete Keys +#### Get Transit Key + +Retrieves metadata for a specific transit key and version. + +- **Endpoint**: `GET /v1/transit/keys/:name` +- **Capability**: `read` +- **Query Params**: + - `version` (optional) - Specific version to retrieve. If omitted, returns latest version. +- **Success**: `200 OK` + +```bash +curl "http://localhost:8080/v1/transit/keys/payment-data?version=1" \ + -H "Authorization: Bearer " +``` + +Example response (`200 OK`): + +```json +{ + "name": "payment-data", + "type": "aes-gcm", + "version": 1, + "created_at": "2026-03-07T12:00:00Z" +} +``` + #### List Transit Keys - **Endpoint**: `GET /v1/transit/keys` diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 51c19ee..88d463c 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -482,6 +482,44 @@ paths: $ref: "#/components/responses/ValidationError" "429": $ref: "#/components/responses/TooManyRequests" + /v1/transit/keys/{name}: + parameters: + - name: name + in: path + required: true + schema: + type: string + get: + tags: [transit] + summary: Get transit key metadata + security: + - bearerAuth: [] + parameters: + - name: version + in: query + description: Specific version to retrieve. If omitted, returns latest version. + schema: + type: integer + minimum: 1 + responses: + "200": + description: Transit key metadata + content: + application/json: + schema: + $ref: "#/components/schemas/TransitKeyMetadataResponse" + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + "404": + description: Transit key not found + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + "429": + $ref: "#/components/responses/TooManyRequests" /v1/transit/keys/{name}/rotate: post: tags: [transit] @@ -1164,6 +1202,20 @@ components: type: string format: date-time required: [id, name, version, created_at] + TransitKeyMetadataResponse: + type: object + properties: + name: + type: string + type: + type: string + description: Algorithm name (e.g., aes-gcm, chacha20-poly1305) + version: + type: integer + created_at: + type: string + format: date-time + required: [name, type, version, created_at] AuditLogResponse: type: object properties: diff --git a/internal/auth/usecase/mocks/mocks.go b/internal/auth/usecase/mocks/mocks.go index 36d8e9a..80c7ce9 100644 --- a/internal/auth/usecase/mocks/mocks.go +++ b/internal/auth/usecase/mocks/mocks.go @@ -1504,13 +1504,11 @@ func (_mock *MockClientUseCase) RotateSecret(ctx context.Context, clientID uuid. r0 = ret.Get(0).(*domain.CreateClientOutput) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { r1 = returnFunc(ctx, clientID) } else { r1 = ret.Error(1) } - return r0, r1 } @@ -1536,17 +1534,20 @@ func (_c *MockClientUseCase_RotateSecret_Call) Run(run func(ctx context.Context, if args[1] != nil { arg1 = args[1].(uuid.UUID) } - run(arg0, arg1) + run( + arg0, + arg1, + ) }) return _c } -func (_c *MockClientUseCase_RotateSecret_Call) Return(_a0 *domain.CreateClientOutput, _a1 error) *MockClientUseCase_RotateSecret_Call { - _c.Call.Return(_a0, _a1) +func (_c *MockClientUseCase_RotateSecret_Call) Return(createClientOutput *domain.CreateClientOutput, err error) *MockClientUseCase_RotateSecret_Call { + _c.Call.Return(createClientOutput, err) return _c } -func (_c *MockClientUseCase_RotateSecret_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*domain.CreateClientOutput, error)) *MockClientUseCase_RotateSecret_Call { +func (_c *MockClientUseCase_RotateSecret_Call) RunAndReturn(run func(ctx context.Context, clientID uuid.UUID) (*domain.CreateClientOutput, error)) *MockClientUseCase_RotateSecret_Call { _c.Call.Return(run) return _c } diff --git a/internal/http/server.go b/internal/http/server.go index aa8e30f..49167d0 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -316,6 +316,12 @@ func (s *Server) registerTransitRoutes( transitKeyHandler.ListHandler, ) + // Get individual transit key + keys.GET("/:name", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + transitKeyHandler.GetHandler, + ) + // Create new transit key keys.POST("", authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), diff --git a/internal/transit/domain/repository.go b/internal/transit/domain/repository.go new file mode 100644 index 0000000..11c119d --- /dev/null +++ b/internal/transit/domain/repository.go @@ -0,0 +1,50 @@ +package domain + +import ( + "context" + "time" + + "github.com/google/uuid" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" +) + +// DekRepository defines the interface for DEK persistence operations within the transit module. +type DekRepository interface { + // Create stores a new DEK in the repository using transaction support from context. + Create(ctx context.Context, dek *cryptoDomain.Dek) error + + // Get retrieves a DEK by its ID. Returns ErrDekNotFound if not found. + Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) +} + +// TransitKeyRepository defines the interface for transit key persistence. +type TransitKeyRepository interface { + // Create stores a new transit key in the repository using transaction support from context. + Create(ctx context.Context, transitKey *TransitKey) error + + // Delete soft deletes a transit key by marking it with DeletedAt timestamp. + Delete(ctx context.Context, transitKeyID uuid.UUID) error + + // GetByName retrieves the latest version of a transit key by name. Returns ErrTransitKeyNotFound if not found. + GetByName(ctx context.Context, name string) (*TransitKey, error) + + // GetByNameAndVersion retrieves a specific version of a transit key. Returns ErrTransitKeyNotFound if not found. + GetByNameAndVersion(ctx context.Context, name string, version uint) (*TransitKey, error) + + // GetTransitKey retrieves a transit key version by name and optional version (0 for latest), + // including its associated encryption algorithm. Returns ErrTransitKeyNotFound if not found. + GetTransitKey(ctx context.Context, name string, version uint) (*TransitKey, cryptoDomain.Algorithm, error) + + // ListCursor retrieves transit keys ordered by name ascending with cursor-based pagination. + // If afterName is provided, returns keys with name greater than afterName (ASC order). + // Returns the latest version for each key. Filters out soft-deleted keys. + // Returns empty slice if no keys found. Limit is pre-validated (1-1000). + ListCursor(ctx context.Context, afterName *string, limit int) ([]*TransitKey, error) + + // HardDelete permanently removes soft-deleted transit keys older than the specified time. + // Only affects keys where deleted_at IS NOT NULL. + // If dryRun is true, returns count without performing deletion. + // Returns the number of keys that were (or would be) deleted. + HardDelete(ctx context.Context, olderThan time.Time, dryRun bool) (int64, error) +} diff --git a/internal/transit/http/dto/response.go b/internal/transit/http/dto/response.go index 384f5b7..b79810b 100644 --- a/internal/transit/http/dto/response.go +++ b/internal/transit/http/dto/response.go @@ -28,6 +28,27 @@ func MapTransitKeyToResponse(transitKey *transitDomain.TransitKey) TransitKeyRes } } +// TransitKeyMetadataResponse represents transit key metadata in API responses. +type TransitKeyMetadataResponse struct { + Name string `json:"name"` + Type string `json:"type"` + Version uint `json:"version"` + CreatedAt time.Time `json:"created_at"` +} + +// MapTransitKeyToMetadataResponse converts a domain transit key and algorithm to an API metadata response. +func MapTransitKeyToMetadataResponse( + transitKey *transitDomain.TransitKey, + alg string, +) TransitKeyMetadataResponse { + return TransitKeyMetadataResponse{ + Name: transitKey.Name, + Type: alg, + Version: transitKey.Version, + CreatedAt: transitKey.CreatedAt, + } +} + // EncryptResponse contains the result of an encryption operation. type EncryptResponse struct { Ciphertext string `json:"ciphertext"` // Format: "version:base64-ciphertext" diff --git a/internal/transit/http/transit_key_handler.go b/internal/transit/http/transit_key_handler.go index f7a7278..77bdc14 100644 --- a/internal/transit/http/transit_key_handler.go +++ b/internal/transit/http/transit_key_handler.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "net/http" + "strconv" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -173,3 +174,46 @@ func (h *TransitKeyHandler) ListHandler(c *gin.Context) { response := dto.MapTransitKeysToListResponse(transitKeys, nextCursor) c.JSON(http.StatusOK, response) } + +// GetHandler retrieves transit key metadata by name and optional version. +// GET /v1/transit/keys/:name?version=1 - Requires ReadCapability. +// Returns 200 OK with transit key metadata and algorithm. +func (h *TransitKeyHandler) GetHandler(c *gin.Context) { + // Extract and validate name from URL parameter + name := c.Param("name") + if name == "" { + httputil.HandleBadRequestGin( + c, + fmt.Errorf("transit key name cannot be empty"), + h.logger, + ) + return + } + + // Extract and validate optional version from query parameter + version := uint(0) + versionStr := c.Query("version") + if versionStr != "" { + v, err := strconv.ParseUint(versionStr, 10, 32) + if err != nil { + httputil.HandleBadRequestGin( + c, + fmt.Errorf("invalid version format: must be a positive integer"), + h.logger, + ) + return + } + version = uint(v) + } + + // Call use case + transitKey, alg, err := h.transitKeyUseCase.Get(c.Request.Context(), name, version) + if err != nil { + httputil.HandleErrorGin(c, err, h.logger) + return + } + + // Map to response + response := dto.MapTransitKeyToMetadataResponse(transitKey, string(alg)) + c.JSON(http.StatusOK, response) +} diff --git a/internal/transit/http/transit_key_handler_test.go b/internal/transit/http/transit_key_handler_test.go index 38ddc80..d71f430 100644 --- a/internal/transit/http/transit_key_handler_test.go +++ b/internal/transit/http/transit_key_handler_test.go @@ -413,3 +413,109 @@ func TestTransitKeyHandler_ListHandler(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) } + +func TestTransitKeyHandler_GetHandler(t *testing.T) { + t.Run("Success_ValidName", func(t *testing.T) { + handler, mockUseCase := setupTestTransitKeyHandler(t) + + now := time.Now().UTC() + expectedKey := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: "test-key", + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: now, + } + expectedAlg := cryptoDomain.AESGCM + + mockUseCase.EXPECT(). + Get(mock.Anything, "test-key", uint(0)). + Return(expectedKey, expectedAlg, nil). + Once() + + c, w := createTestContext(http.MethodGet, "/v1/transit/keys/test-key", nil) + c.Params = gin.Params{gin.Param{Key: "name", Value: "test-key"}} + + handler.GetHandler(c) + + assert.Equal(t, http.StatusOK, w.Code) + + var response dto.TransitKeyMetadataResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-key", response.Name) + assert.Equal(t, "aes-gcm", response.Type) + assert.Equal(t, uint(1), response.Version) + }) + + t.Run("Success_ValidNameAndVersion", func(t *testing.T) { + handler, mockUseCase := setupTestTransitKeyHandler(t) + + now := time.Now().UTC() + expectedKey := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: "test-key", + Version: 2, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: now, + } + expectedAlg := cryptoDomain.ChaCha20 + + mockUseCase.EXPECT(). + Get(mock.Anything, "test-key", uint(2)). + Return(expectedKey, expectedAlg, nil). + Once() + + c, w := createTestContext(http.MethodGet, "/v1/transit/keys/test-key?version=2", nil) + c.Params = gin.Params{gin.Param{Key: "name", Value: "test-key"}} + + handler.GetHandler(c) + + assert.Equal(t, http.StatusOK, w.Code) + + var response dto.TransitKeyMetadataResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-key", response.Name) + assert.Equal(t, "chacha20-poly1305", response.Type) + assert.Equal(t, uint(2), response.Version) + }) + + t.Run("Error_EmptyName", func(t *testing.T) { + handler, _ := setupTestTransitKeyHandler(t) + + c, w := createTestContext(http.MethodGet, "/v1/transit/keys/", nil) + c.Params = gin.Params{gin.Param{Key: "name", Value: ""}} + + handler.GetHandler(c) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("Error_InvalidVersion", func(t *testing.T) { + handler, _ := setupTestTransitKeyHandler(t) + + c, w := createTestContext(http.MethodGet, "/v1/transit/keys/test-key?version=invalid", nil) + c.Params = gin.Params{gin.Param{Key: "name", Value: "test-key"}} + + handler.GetHandler(c) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("Error_NotFound", func(t *testing.T) { + handler, mockUseCase := setupTestTransitKeyHandler(t) + + mockUseCase.EXPECT(). + Get(mock.Anything, "nonexistent", uint(0)). + Return(nil, cryptoDomain.Algorithm(""), transitDomain.ErrTransitKeyNotFound). + Once() + + c, w := createTestContext(http.MethodGet, "/v1/transit/keys/nonexistent", nil) + c.Params = gin.Params{gin.Param{Key: "name", Value: "nonexistent"}} + + handler.GetHandler(c) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) +} diff --git a/internal/transit/repository/mysql/mysql_transit_key_repository.go b/internal/transit/repository/mysql/mysql_transit_key_repository.go index db60859..237095e 100644 --- a/internal/transit/repository/mysql/mysql_transit_key_repository.go +++ b/internal/transit/repository/mysql/mysql_transit_key_repository.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" transitDomain "github.com/allisson/secrets/internal/transit/domain" @@ -155,6 +156,66 @@ func (m *MySQLTransitKeyRepository) GetByNameAndVersion( return &transitKey, nil } +// GetTransitKey retrieves a transit key version by name and optional version (0 for latest), +// including its associated encryption algorithm. Returns ErrTransitKeyNotFound if not found. +func (m *MySQLTransitKeyRepository) GetTransitKey( + ctx context.Context, + name string, + version uint, +) (*transitDomain.TransitKey, cryptoDomain.Algorithm, error) { + querier := database.GetTx(ctx, m.db) + + var query string + var args []interface{} + + if version == 0 { + query = `SELECT tk.id, tk.name, tk.version, tk.dek_id, tk.created_at, tk.deleted_at, d.algorithm + FROM transit_keys tk + JOIN deks d ON tk.dek_id = d.id + WHERE tk.name = ? AND tk.deleted_at IS NULL + ORDER BY tk.version DESC + LIMIT 1` + args = []interface{}{name} + } else { + query = `SELECT tk.id, tk.name, tk.version, tk.dek_id, tk.created_at, tk.deleted_at, d.algorithm + FROM transit_keys tk + JOIN deks d ON tk.dek_id = d.id + WHERE tk.name = ? AND tk.version = ? AND tk.deleted_at IS NULL` + args = []interface{}{name, version} + } + + var transitKey transitDomain.TransitKey + var id []byte + var dekID []byte + var algorithm cryptoDomain.Algorithm + + err := querier.QueryRowContext(ctx, query, args...).Scan( + &id, + &transitKey.Name, + &transitKey.Version, + &dekID, + &transitKey.CreatedAt, + &transitKey.DeletedAt, + &algorithm, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, "", transitDomain.ErrTransitKeyNotFound + } + return nil, "", apperrors.Wrap(err, "failed to get transit key") + } + + if err := transitKey.ID.UnmarshalBinary(id); err != nil { + return nil, "", apperrors.Wrap(err, "failed to unmarshal transit key id") + } + + if err := transitKey.DekID.UnmarshalBinary(dekID); err != nil { + return nil, "", apperrors.Wrap(err, "failed to unmarshal dek id") + } + + return &transitKey, algorithm, nil +} + // ListCursor retrieves transit keys ordered by name ascending using cursor-based pagination. // Returns the latest version for each key. func (m *MySQLTransitKeyRepository) ListCursor( diff --git a/internal/transit/repository/mysql/mysql_transit_key_repository_test.go b/internal/transit/repository/mysql/mysql_transit_key_repository_test.go index 0756ed8..b942ce7 100644 --- a/internal/transit/repository/mysql/mysql_transit_key_repository_test.go +++ b/internal/transit/repository/mysql/mysql_transit_key_repository_test.go @@ -857,6 +857,75 @@ func createTestDekMySQL(t *testing.T, db *sql.DB) uuid.UUID { return dekID } +func TestMySQLTransitKeyRepository_GetTransitKey(t *testing.T) { + db := testutil.SetupMySQLDB(t) + defer testutil.TeardownDB(t, db) + defer testutil.CleanupMySQLDB(t, db) + + repo := NewMySQLTransitKeyRepository(db) + ctx := context.Background() + + // Create prerequisite KEK and DEK + dekID := createTestDekMySQL(t, db) + algorithm := cryptoDomain.AESGCM + + name := "test-key" + + // Create version 1 + key1 := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 1, + DekID: dekID, + CreatedAt: time.Now().UTC(), + } + err := repo.Create(ctx, key1) + require.NoError(t, err) + + // Create version 2 + key2 := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 2, + DekID: dekID, + CreatedAt: time.Now().UTC().Add(time.Hour), + } + err = repo.Create(ctx, key2) + require.NoError(t, err) + + t.Run("Get latest version", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 0) + require.NoError(t, err) + assert.Equal(t, key2.ID, tk.ID) + assert.Equal(t, name, tk.Name) + assert.Equal(t, uint(2), tk.Version) + assert.Equal(t, algorithm, alg) + }) + + t.Run("Get specific version", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 1) + require.NoError(t, err) + assert.Equal(t, key1.ID, tk.ID) + assert.Equal(t, name, tk.Name) + assert.Equal(t, uint(1), tk.Version) + assert.Equal(t, algorithm, alg) + }) + + t.Run("Key not found", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, "non-existent", 0) + assert.ErrorIs(t, err, transitDomain.ErrTransitKeyNotFound) + assert.Nil(t, tk) + assert.Empty(t, alg) + }) + + t.Run("Version not found", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 3) + assert.ErrorIs(t, err, transitDomain.ErrTransitKeyNotFound) + assert.Nil(t, tk) + assert.Empty(t, alg) + }) +} + func TestMySQLTransitKeyRepository_ListCursor_FirstPage(t *testing.T) { db := testutil.SetupMySQLDB(t) defer testutil.TeardownDB(t, db) diff --git a/internal/transit/repository/postgresql/postgresql_transit_key_repository.go b/internal/transit/repository/postgresql/postgresql_transit_key_repository.go index cccb43b..fb35f5f 100644 --- a/internal/transit/repository/postgresql/postgresql_transit_key_repository.go +++ b/internal/transit/repository/postgresql/postgresql_transit_key_repository.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" transitDomain "github.com/allisson/secrets/internal/transit/domain" @@ -123,6 +124,55 @@ func (p *PostgreSQLTransitKeyRepository) GetByNameAndVersion( return &transitKey, nil } +// GetTransitKey retrieves a transit key version by name and optional version (0 for latest), +// including its associated encryption algorithm. Returns ErrTransitKeyNotFound if not found. +func (p *PostgreSQLTransitKeyRepository) GetTransitKey( + ctx context.Context, + name string, + version uint, +) (*transitDomain.TransitKey, cryptoDomain.Algorithm, error) { + querier := database.GetTx(ctx, p.db) + + var query string + var args []interface{} + + if version == 0 { + query = `SELECT tk.id, tk.name, tk.version, tk.dek_id, tk.created_at, tk.deleted_at, d.algorithm + FROM transit_keys tk + JOIN deks d ON tk.dek_id = d.id + WHERE tk.name = $1 AND tk.deleted_at IS NULL + ORDER BY tk.version DESC + LIMIT 1` + args = []interface{}{name} + } else { + query = `SELECT tk.id, tk.name, tk.version, tk.dek_id, tk.created_at, tk.deleted_at, d.algorithm + FROM transit_keys tk + JOIN deks d ON tk.dek_id = d.id + WHERE tk.name = $1 AND tk.version = $2 AND tk.deleted_at IS NULL` + args = []interface{}{name, version} + } + + var transitKey transitDomain.TransitKey + var algorithm cryptoDomain.Algorithm + err := querier.QueryRowContext(ctx, query, args...).Scan( + &transitKey.ID, + &transitKey.Name, + &transitKey.Version, + &transitKey.DekID, + &transitKey.CreatedAt, + &transitKey.DeletedAt, + &algorithm, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, "", transitDomain.ErrTransitKeyNotFound + } + return nil, "", apperrors.Wrap(err, "failed to get transit key") + } + + return &transitKey, algorithm, nil +} + // ListCursor retrieves transit keys ordered by name ascending using cursor-based pagination. // Returns the latest version for each key. func (p *PostgreSQLTransitKeyRepository) ListCursor( diff --git a/internal/transit/repository/postgresql/postgresql_transit_key_repository_test.go b/internal/transit/repository/postgresql/postgresql_transit_key_repository_test.go index d446c8b..7a20233 100644 --- a/internal/transit/repository/postgresql/postgresql_transit_key_repository_test.go +++ b/internal/transit/repository/postgresql/postgresql_transit_key_repository_test.go @@ -806,6 +806,75 @@ func createTestDek(t *testing.T, db *sql.DB) uuid.UUID { return dekID } +func TestPostgreSQLTransitKeyRepository_GetTransitKey(t *testing.T) { + db := testutil.SetupPostgresDB(t) + defer testutil.TeardownDB(t, db) + defer testutil.CleanupPostgresDB(t, db) + + repo := NewPostgreSQLTransitKeyRepository(db) + ctx := context.Background() + + // Create prerequisite KEK and DEK + dekID := createTestDek(t, db) + algorithm := cryptoDomain.AESGCM + + name := "test-key" + + // Create version 1 + key1 := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 1, + DekID: dekID, + CreatedAt: time.Now().UTC(), + } + err := repo.Create(ctx, key1) + require.NoError(t, err) + + // Create version 2 + key2 := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 2, + DekID: dekID, + CreatedAt: time.Now().UTC().Add(time.Hour), + } + err = repo.Create(ctx, key2) + require.NoError(t, err) + + t.Run("Get latest version", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 0) + require.NoError(t, err) + assert.Equal(t, key2.ID, tk.ID) + assert.Equal(t, name, tk.Name) + assert.Equal(t, uint(2), tk.Version) + assert.Equal(t, algorithm, alg) + }) + + t.Run("Get specific version", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 1) + require.NoError(t, err) + assert.Equal(t, key1.ID, tk.ID) + assert.Equal(t, name, tk.Name) + assert.Equal(t, uint(1), tk.Version) + assert.Equal(t, algorithm, alg) + }) + + t.Run("Key not found", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, "non-existent", 0) + assert.ErrorIs(t, err, transitDomain.ErrTransitKeyNotFound) + assert.Nil(t, tk) + assert.Empty(t, alg) + }) + + t.Run("Version not found", func(t *testing.T) { + tk, alg, err := repo.GetTransitKey(ctx, name, 3) + assert.ErrorIs(t, err, transitDomain.ErrTransitKeyNotFound) + assert.Nil(t, tk) + assert.Empty(t, alg) + }) +} + func TestPostgreSQLTransitKeyRepository_ListCursor_FirstPage(t *testing.T) { db := testutil.SetupPostgresDB(t) defer testutil.TeardownDB(t, db) diff --git a/internal/transit/usecase/interface.go b/internal/transit/usecase/interface.go index ec1f1f3..947d14a 100644 --- a/internal/transit/usecase/interface.go +++ b/internal/transit/usecase/interface.go @@ -4,7 +4,6 @@ package usecase import ( "context" - "time" "github.com/google/uuid" @@ -12,41 +11,10 @@ import ( transitDomain "github.com/allisson/secrets/internal/transit/domain" ) -// DekRepository defines the interface for DEK persistence operations. -type DekRepository interface { - // Create stores a new DEK in the repository using transaction support from context. - Create(ctx context.Context, dek *cryptoDomain.Dek) error - - // Get retrieves a DEK by its ID. Returns ErrDekNotFound if not found. - Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) -} - -// TransitKeyRepository defines the interface for transit key persistence. -type TransitKeyRepository interface { - // Create stores a new transit key in the repository using transaction support from context. - Create(ctx context.Context, transitKey *transitDomain.TransitKey) error - - // Delete soft deletes a transit key by marking it with DeletedAt timestamp. - Delete(ctx context.Context, transitKeyID uuid.UUID) error - - // GetByName retrieves the latest version of a transit key by name. Returns ErrTransitKeyNotFound if not found. - GetByName(ctx context.Context, name string) (*transitDomain.TransitKey, error) - - // GetByNameAndVersion retrieves a specific version of a transit key. Returns ErrTransitKeyNotFound if not found. - GetByNameAndVersion(ctx context.Context, name string, version uint) (*transitDomain.TransitKey, error) - - // ListCursor retrieves transit keys ordered by name ascending with cursor-based pagination. - // If afterName is provided, returns keys with name greater than afterName (ASC order). - // Returns the latest version for each key. Filters out soft-deleted keys. - // Returns empty slice if no keys found. Limit is pre-validated (1-1000). - ListCursor(ctx context.Context, afterName *string, limit int) ([]*transitDomain.TransitKey, error) - - // HardDelete permanently removes soft-deleted transit keys older than the specified time. - // Only affects keys where deleted_at IS NOT NULL. - // If dryRun is true, returns count without performing deletion. - // Returns the number of keys that were (or would be) deleted. - HardDelete(ctx context.Context, olderThan time.Time, dryRun bool) (int64, error) -} +// Re-export repository interfaces for convenience and backward compatibility if needed. +// However, the canonical location is now internal/transit/domain/repository.go. +type DekRepository = transitDomain.DekRepository +type TransitKeyRepository = transitDomain.TransitKeyRepository // TransitKeyUseCase defines the interface for transit encryption operations. type TransitKeyUseCase interface { @@ -58,6 +26,14 @@ type TransitKeyUseCase interface { // Generates a new DEK for the new version while preserving old versions for decryption. Rotate(ctx context.Context, name string, alg cryptoDomain.Algorithm) (*transitDomain.TransitKey, error) + // Get retrieves transit key metadata (including its algorithm) by name and optional version. + // If version is 0, the latest version is retrieved. + Get( + ctx context.Context, + name string, + version uint, + ) (*transitDomain.TransitKey, cryptoDomain.Algorithm, error) + // Delete soft deletes a transit key and all its versions by transit key ID. Delete(ctx context.Context, transitKeyID uuid.UUID) error diff --git a/internal/transit/usecase/metrics_decorator.go b/internal/transit/usecase/metrics_decorator.go index 8d6ec46..242b007 100644 --- a/internal/transit/usecase/metrics_decorator.go +++ b/internal/transit/usecase/metrics_decorator.go @@ -65,6 +65,26 @@ func (t *transitKeyUseCaseWithMetrics) Rotate( return key, err } +// Get records metrics for transit key metadata retrieval operations. +func (t *transitKeyUseCaseWithMetrics) Get( + ctx context.Context, + name string, + version uint, +) (*transitDomain.TransitKey, cryptoDomain.Algorithm, error) { + start := time.Now() + key, alg, err := t.next.Get(ctx, name, version) + + status := "success" + if err != nil { + status = "error" + } + + t.metrics.RecordOperation(ctx, "transit", "transit_key_get", status) + t.metrics.RecordDuration(ctx, "transit", "transit_key_get", time.Since(start), status) + + return key, alg, err +} + // Delete records metrics for transit key deletion operations. func (t *transitKeyUseCaseWithMetrics) Delete(ctx context.Context, transitKeyID uuid.UUID) error { start := time.Now() diff --git a/internal/transit/usecase/metrics_decorator_test.go b/internal/transit/usecase/metrics_decorator_test.go index 1457b4e..03f9817 100644 --- a/internal/transit/usecase/metrics_decorator_test.go +++ b/internal/transit/usecase/metrics_decorator_test.go @@ -301,3 +301,63 @@ func TestTransitKeyUseCaseWithMetrics_Decrypt(t *testing.T) { mockMetrics.AssertExpectations(t) }) } + +func TestTransitKeyUseCaseWithMetrics_Get(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + name := "test-key" + version := uint(1) + + t.Run("Get_Success", func(t *testing.T) { + // Arrange + expectedKey := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + } + expectedAlg := cryptoDomain.AESGCM + + mockNext.EXPECT().Get(ctx, name, version).Return(expectedKey, expectedAlg, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_get", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_get", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + key, alg, err := uc.Get(ctx, name, version) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedKey, key) + assert.Equal(t, expectedAlg, alg) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Get_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("get failed") + + mockNext.EXPECT().Get(ctx, name, version).Return(nil, cryptoDomain.Algorithm(""), expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_get", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_get", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + key, alg, err := uc.Get(ctx, name, version) + + // Assert + assert.Error(t, err) + assert.Nil(t, key) + assert.Equal(t, cryptoDomain.Algorithm(""), alg) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} diff --git a/internal/transit/usecase/mocks/mocks.go b/internal/transit/usecase/mocks/mocks.go index 8a1f68b..afbdd9a 100644 --- a/internal/transit/usecase/mocks/mocks.go +++ b/internal/transit/usecase/mocks/mocks.go @@ -444,11 +444,92 @@ func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) Return(transitKey * return _c } -func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, error)) *MockTransitKeyRepository_GetByNameAndVersion_Call { +func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(context.Context, string, uint) (*domain0.TransitKey, error)) *MockTransitKeyRepository_GetByNameAndVersion_Call { _c.Call.Return(run) return _c } +// GetTransitKey provides a mock function for the type MockTransitKeyRepository +func (_mock *MockTransitKeyRepository) GetTransitKey(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error) { + ret := _mock.Called(ctx, name, version) + + if len(ret) == 0 { + panic("no return value specified for GetTransitKey") + } + + var r0 *domain0.TransitKey + var r1 domain.Algorithm + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TransitKey, domain.Algorithm, error)); ok { + return returnFunc(ctx, name, version) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TransitKey); ok { + r0 = returnFunc(ctx, name, version) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*domain0.TransitKey) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain.Algorithm); ok { + r1 = returnFunc(ctx, name, version) + } else { + r1 = ret.Get(1).(domain.Algorithm) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, string, uint) error); ok { + r2 = returnFunc(ctx, name, version) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// MockTransitKeyRepository_GetTransitKey_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTransitKey' +type MockTransitKeyRepository_GetTransitKey_Call struct { + *mock.Call +} + +// GetTransitKey is a helper method to define mock.On call +// - ctx context.Context +// - name string +// - version uint +func (_e *MockTransitKeyRepository_Expecter) GetTransitKey(ctx interface{}, name interface{}, version interface{}) *MockTransitKeyRepository_GetTransitKey_Call { + return &MockTransitKeyRepository_GetTransitKey_Call{Call: _e.mock.On("GetTransitKey", ctx, name, version)} +} + +func (_c *MockTransitKeyRepository_GetTransitKey_Call) Run(run func(ctx context.Context, name string, version uint)) *MockTransitKeyRepository_GetTransitKey_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint + if args[2] != nil { + arg2 = args[2].(uint) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockTransitKeyRepository_GetTransitKey_Call) Return(_a0 *domain0.TransitKey, _a1 domain.Algorithm, _a2 error) *MockTransitKeyRepository_GetTransitKey_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockTransitKeyRepository_GetTransitKey_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error)) *MockTransitKeyRepository_GetTransitKey_Call { + _c.Call.Return(run) + return _c +} + +// HardDelete provides a mock function for the type MockTransitKeyRepository // HardDelete provides a mock function for the type MockTransitKeyRepository func (_mock *MockTransitKeyRepository) HardDelete(ctx context.Context, olderThan time.Time, dryRun bool) (int64, error) { ret := _mock.Called(ctx, olderThan, dryRun) @@ -776,6 +857,86 @@ func (_c *MockTransitKeyUseCase_Decrypt_Call) RunAndReturn(run func(ctx context. return _c } +// Get provides a mock function for the type MockTransitKeyUseCase +func (_mock *MockTransitKeyUseCase) Get(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error) { + ret := _mock.Called(ctx, name, version) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *domain0.TransitKey + var r1 domain.Algorithm + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TransitKey, domain.Algorithm, error)); ok { + return returnFunc(ctx, name, version) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TransitKey); ok { + r0 = returnFunc(ctx, name, version) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*domain0.TransitKey) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain.Algorithm); ok { + r1 = returnFunc(ctx, name, version) + } else { + r1 = ret.Get(1).(domain.Algorithm) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, string, uint) error); ok { + r2 = returnFunc(ctx, name, version) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// MockTransitKeyUseCase_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockTransitKeyUseCase_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - name string +// - version uint +func (_e *MockTransitKeyUseCase_Expecter) Get(ctx interface{}, name interface{}, version interface{}) *MockTransitKeyUseCase_Get_Call { + return &MockTransitKeyUseCase_Get_Call{Call: _e.mock.On("Get", ctx, name, version)} +} + +func (_c *MockTransitKeyUseCase_Get_Call) Run(run func(ctx context.Context, name string, version uint)) *MockTransitKeyUseCase_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint + if args[2] != nil { + arg2 = args[2].(uint) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockTransitKeyUseCase_Get_Call) Return(_a0 *domain0.TransitKey, _a1 domain.Algorithm, _a2 error) *MockTransitKeyUseCase_Get_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockTransitKeyUseCase_Get_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error)) *MockTransitKeyUseCase_Get_Call { + _c.Call.Return(run) + return _c +} + // Delete provides a mock function for the type MockTransitKeyUseCase func (_mock *MockTransitKeyUseCase) Delete(ctx context.Context, transitKeyID uuid.UUID) error { ret := _mock.Called(ctx, transitKeyID) @@ -913,6 +1074,8 @@ func (_c *MockTransitKeyUseCase_Encrypt_Call) RunAndReturn(run func(ctx context. return _c } +// Get provides a mock function for the type MockTransitKeyUseCase + // ListCursor provides a mock function for the type MockTransitKeyUseCase func (_mock *MockTransitKeyUseCase) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain0.TransitKey, error) { ret := _mock.Called(ctx, afterName, limit) diff --git a/internal/transit/usecase/transit_key_usecase.go b/internal/transit/usecase/transit_key_usecase.go index 467f8e7..87259ef 100644 --- a/internal/transit/usecase/transit_key_usecase.go +++ b/internal/transit/usecase/transit_key_usecase.go @@ -151,6 +151,15 @@ func (t *transitKeyUseCase) Rotate( return newTransitKey, nil } +// Get retrieves transit key metadata (including its algorithm) by name and optional version. +func (t *transitKeyUseCase) Get( + ctx context.Context, + name string, + version uint, +) (*transitDomain.TransitKey, cryptoDomain.Algorithm, error) { + return t.transitRepo.GetTransitKey(ctx, name, version) +} + // Delete soft-deletes a transit key by setting its deleted_at timestamp. func (t *transitKeyUseCase) Delete(ctx context.Context, transitKeyID uuid.UUID) error { return t.transitRepo.Delete(ctx, transitKeyID) diff --git a/internal/transit/usecase/transit_key_usecase_test.go b/internal/transit/usecase/transit_key_usecase_test.go index 07b6fb7..18f72f3 100644 --- a/internal/transit/usecase/transit_key_usecase_test.go +++ b/internal/transit/usecase/transit_key_usecase_test.go @@ -1462,3 +1462,75 @@ func TestTransitKeyUseCase_PurgeDeleted(t *testing.T) { assert.Equal(t, expectedError, err) }) } + +// TestTransitKeyUseCase_Get tests the Get method of transitKeyUseCase. +func TestTransitKeyUseCase_Get(t *testing.T) { + ctx := context.Background() + + t.Run("Success_GetTransitKey", func(t *testing.T) { + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) + mockDekRepo := usecaseMocks.NewMockDekRepository(t) + mockKeyManager := serviceMocks.NewMockKeyManager(t) + mockAeadManager := serviceMocks.NewMockAEADManager(t) + + // Create test data + kek := createTestKek() + kekChain := createTestKekChain(kek.ID, kek) + defer kekChain.Close() + + expectedKey := createTestTransitKey("test-key", 1, uuid.Must(uuid.NewV7())) + expectedAlg := cryptoDomain.AESGCM + + // Setup expectations + mockTransitRepo.EXPECT(). + GetTransitKey(ctx, "test-key", uint(1)). + Return(expectedKey, expectedAlg, nil). + Once() + + // Execute + uc := NewTransitKeyUseCase( + mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, + ) + key, alg, err := uc.Get(ctx, "test-key", 1) + + // Assert + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, expectedKey, key) + assert.Equal(t, expectedAlg, alg) + }) + + t.Run("Error_GetTransitKeyNotFound", func(t *testing.T) { + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) + mockDekRepo := usecaseMocks.NewMockDekRepository(t) + mockKeyManager := serviceMocks.NewMockKeyManager(t) + mockAeadManager := serviceMocks.NewMockAEADManager(t) + + // Create test data + kek := createTestKek() + kekChain := createTestKekChain(kek.ID, kek) + defer kekChain.Close() + + // Setup expectations + mockTransitRepo.EXPECT(). + GetTransitKey(ctx, "test-key", uint(1)). + Return(nil, cryptoDomain.Algorithm(""), transitDomain.ErrTransitKeyNotFound). + Once() + + // Execute + uc := NewTransitKeyUseCase( + mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, + ) + key, alg, err := uc.Get(ctx, "test-key", 1) + + // Assert + assert.Error(t, err) + assert.Nil(t, key) + assert.Equal(t, cryptoDomain.Algorithm(""), alg) + assert.True(t, apperrors.Is(err, transitDomain.ErrTransitKeyNotFound)) + }) +} diff --git a/test/integration/transit_flow_test.go b/test/integration/transit_flow_test.go index e751787..acc997b 100644 --- a/test/integration/transit_flow_test.go +++ b/test/integration/transit_flow_test.go @@ -75,8 +75,22 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { transitKeyID = parsedID }) - // [2/8] Test POST /v1/transit/keys/:name/encrypt - Encrypt with transit key - t.Run("02_Encrypt", func(t *testing.T) { + // [2/11] Test GET /v1/transit/keys/:name - Get transit key + t.Run("02_GetTransitKey", func(t *testing.T) { + resp, body := ctx.makeRequest(t, http.MethodGet, "/v1/transit/keys/"+transitKeyName, nil, true) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var response transitDTO.TransitKeyMetadataResponse + err := json.Unmarshal(body, &response) + require.NoError(t, err) + assert.Equal(t, transitKeyName, response.Name) + assert.Equal(t, "aes-gcm", response.Type) + assert.Equal(t, uint(1), response.Version) + assert.False(t, response.CreatedAt.IsZero()) + }) + + // [3/11] Test POST /v1/transit/keys/:name/encrypt - Encrypt with transit key + t.Run("03_Encrypt", func(t *testing.T) { requestBody := transitDTO.EncryptRequest{ Plaintext: base64.StdEncoding.EncodeToString(plaintext1), } @@ -180,8 +194,20 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { assert.Equal(t, uint(2), response.Version) // Version should increment to 2 }) - // [6/8] Test POST /v1/transit/keys/:name/encrypt - Encrypt with rotated key (version 2) - t.Run("06_EncryptWithRotatedKey", func(t *testing.T) { + // [7/11] Test GET /v1/transit/keys/:name?version=1 - Get specific version + t.Run("07_GetSpecificVersion", func(t *testing.T) { + resp, body := ctx.makeRequest(t, http.MethodGet, "/v1/transit/keys/"+transitKeyName+"?version=1", nil, true) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var response transitDTO.TransitKeyMetadataResponse + err := json.Unmarshal(body, &response) + require.NoError(t, err) + assert.Equal(t, transitKeyName, response.Name) + assert.Equal(t, uint(1), response.Version) + }) + + // [8/11] Test POST /v1/transit/keys/:name/encrypt - Encrypt with rotated key (version 2) + t.Run("08_EncryptWithRotatedKey", func(t *testing.T) { requestBody := transitDTO.EncryptRequest{ Plaintext: base64.StdEncoding.EncodeToString(plaintext1), } @@ -208,8 +234,8 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { assert.NotEqual(t, ciphertext1, ciphertextV2) }) - // [7/8] Test POST /v1/transit/keys/:name/decrypt - Decrypt old ciphertext (backward compatibility) - t.Run("07_DecryptOldCiphertext", func(t *testing.T) { + // [9/11] Test POST /v1/transit/keys/:name/decrypt - Decrypt old ciphertext (backward compatibility) + t.Run("09_DecryptOldCiphertext", func(t *testing.T) { requestBody := transitDTO.DecryptRequest{ Ciphertext: ciphertext1, // Use version 1 ciphertext } @@ -235,8 +261,8 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { assert.Equal(t, plaintext1, decoded) }) - // [8/8] Test AEAD Context - t.Run("08_AEADContext", func(t *testing.T) { + // [10/11] Test AEAD Context + t.Run("10_AEADContext", func(t *testing.T) { contextAAD := []byte("integration-test-context") wrongContext := []byte("wrong-context") @@ -296,8 +322,8 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) }) - // [9/9] Test DELETE /v1/transit/keys/:id - Delete transit key - t.Run("09_DeleteTransitKey", func(t *testing.T) { + // [11/11] Test DELETE /v1/transit/keys/:id - Delete transit key + t.Run("11_DeleteTransitKey", func(t *testing.T) { resp, body := ctx.makeRequest( t, http.MethodDelete, @@ -309,7 +335,8 @@ func TestIntegration_Transit_CompleteFlow(t *testing.T) { assert.Empty(t, body) }) - t.Logf("All 9 transit endpoint tests passed for %s", tc.dbDriver) + t.Logf("All 11 transit endpoint tests passed for %s", tc.dbDriver) + }) } }