diff --git a/broker/api/api-handler.go b/broker/api/api-handler.go index 11b078f2..6e1b9148 100644 --- a/broker/api/api-handler.go +++ b/broker/api/api-handler.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "net/url" "strings" "time" @@ -852,41 +851,3 @@ func toString(text pgtype.Text) *string { return nil } } - -func ToLinkUrlValues(r *http.Request, urlValues url.Values) string { - return toLinkPath(r, r.URL.Path, urlValues.Encode()) -} - -func toLink(r *http.Request, path string, id string, query string) string { - if strings.Contains(r.RequestURI, "/broker/") { - path = "/broker" + path - } - if id != "" { - path = path + "/" + id - } - return toLinkPath(r, path, query) -} - -func toLinkPath(r *http.Request, path string, query string) string { - if query != "" { - path = path + "?" + query - } - urlScheme := r.Header.Get("X-Forwarded-Proto") - if len(urlScheme) == 0 { - urlScheme = r.URL.Scheme - } - if len(urlScheme) == 0 { - urlScheme = "https" - } - urlHost := r.Header.Get("X-Forwarded-Host") - if len(urlHost) == 0 { - urlHost = r.URL.Host - } - if len(urlHost) == 0 { - urlHost = r.Host - } - if strings.Contains(urlHost, "localhost") { - urlScheme = "http" - } - return urlScheme + "://" + urlHost + path -} diff --git a/broker/api/common.go b/broker/api/common.go index 7c39aa81..883eb964 100644 --- a/broker/api/common.go +++ b/broker/api/common.go @@ -3,6 +3,7 @@ package api import ( "errors" "net/http" + "net/url" "strconv" "strings" @@ -10,6 +11,44 @@ import ( "github.com/indexdata/crosslink/broker/oapi" ) +func ToLinkUrlValues(r *http.Request, urlValues url.Values) string { + return ToLinkPath(r, r.URL.Path, urlValues.Encode()) +} + +func toLink(r *http.Request, path string, id string, query string) string { + if strings.Contains(r.RequestURI, "/broker/") { + path = "/broker" + path + } + if id != "" { + path = path + "/" + id + } + return ToLinkPath(r, path, query) +} + +func ToLinkPath(r *http.Request, path string, query string) string { + if query != "" { + path = path + "?" + query + } + urlScheme := r.Header.Get("X-Forwarded-Proto") + if len(urlScheme) == 0 { + urlScheme = r.URL.Scheme + } + if len(urlScheme) == 0 { + urlScheme = "https" + } + urlHost := r.Header.Get("X-Forwarded-Host") + if len(urlHost) == 0 { + urlHost = r.URL.Host + } + if len(urlHost) == 0 { + urlHost = r.Host + } + if strings.Contains(urlHost, "localhost") { + urlScheme = "http" + } + return urlScheme + "://" + urlHost + path +} + func CollectAboutData(fullCount int64, offset int32, limit int32, r *http.Request) oapi.About { about := oapi.About{} about.Count = fullCount diff --git a/broker/patron_request/api/api-handler.go b/broker/patron_request/api/api-handler.go index b1060956..d9d5de34 100644 --- a/broker/patron_request/api/api-handler.go +++ b/broker/patron_request/api/api-handler.go @@ -20,7 +20,9 @@ import ( prservice "github.com/indexdata/crosslink/broker/patron_request/service" "github.com/indexdata/crosslink/iso18626" "github.com/indexdata/go-utils/utils" + "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) @@ -183,8 +185,13 @@ func (a *PatronRequestApiHandler) PostPatronRequests(w http.ResponseWriter, r *h addInternalError(ctx, w, err) return } - pr, err := a.prRepo.SavePatronRequest(ctx, (pr_db.SavePatronRequestParams)(dbreq)) + pr, err := a.prRepo.CreatePatronRequest(ctx, (pr_db.CreatePatronRequestParams)(dbreq)) if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgerrcode.IsIntegrityConstraintViolation(pgErr.Code) { + addBadRequestError(ctx, w, errors.New("a patron request with this ID already exists")) + return + } addInternalError(ctx, w, err) return } @@ -205,6 +212,7 @@ func (a *PatronRequestApiHandler) PostPatronRequests(w http.ResponseWriter, r *h addInternalError(ctx, w, err) return } + w.Header().Set("Location", api.ToLinkPath(r, r.URL.Path+"/"+pr.ID, "")) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(toApiPatronRequest(pr, illRequest)) diff --git a/broker/patron_request/api/api-handler_test.go b/broker/patron_request/api/api-handler_test.go index 81e7b7c9..5d88070d 100644 --- a/broker/patron_request/api/api-handler_test.go +++ b/broker/patron_request/api/api-handler_test.go @@ -341,18 +341,26 @@ func (r *PrRepoError) GetPatronRequestById(ctx common.ExtendedContext, id string return pr_db.PatronRequest{}, errors.New("DB error") } } + func (r *PrRepoError) ListPatronRequests(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, cql *string) ([]pr_db.PatronRequest, int64, error) { return []pr_db.PatronRequest{}, 0, errors.New("DB error") } -func (r *PrRepoError) SavePatronRequest(ctx common.ExtendedContext, params pr_db.SavePatronRequestParams) (pr_db.PatronRequest, error) { + +func (r *PrRepoError) UpdatePatronRequest(ctx common.ExtendedContext, params pr_db.UpdatePatronRequestParams) (pr_db.PatronRequest, error) { return pr_db.PatronRequest{}, errors.New("DB error") } + +func (r *PrRepoError) CreatePatronRequest(ctx common.ExtendedContext, params pr_db.CreatePatronRequestParams) (pr_db.PatronRequest, error) { + return pr_db.PatronRequest{}, errors.New("DB error") +} + func (r *PrRepoError) DeletePatronRequest(ctx common.ExtendedContext, id string) error { if id == "4" { return nil } return errors.New("DB error") } + func (r *PrRepoError) GetNextHrid(ctx common.ExtendedContext, prefix string) (string, error) { r.counter++ return strings.ToUpper(prefix) + "-" + strconv.FormatInt(r.counter, 10), nil diff --git a/broker/patron_request/db/prrepo.go b/broker/patron_request/db/prrepo.go index a68926ac..b8c8fe45 100644 --- a/broker/patron_request/db/prrepo.go +++ b/broker/patron_request/db/prrepo.go @@ -12,7 +12,8 @@ type PrRepo interface { repo.Transactional[PrRepo] GetPatronRequestById(ctx common.ExtendedContext, id string) (PatronRequest, error) ListPatronRequests(ctx common.ExtendedContext, args ListPatronRequestsParams, cql *string) ([]PatronRequest, int64, error) - SavePatronRequest(ctx common.ExtendedContext, params SavePatronRequestParams) (PatronRequest, error) + UpdatePatronRequest(ctx common.ExtendedContext, params UpdatePatronRequestParams) (PatronRequest, error) + CreatePatronRequest(ctx common.ExtendedContext, params CreatePatronRequestParams) (PatronRequest, error) DeletePatronRequest(ctx common.ExtendedContext, id string) error GetPatronRequestBySupplierSymbolAndRequesterReqId(ctx common.ExtendedContext, supplierSymbol string, requesterReId string) (PatronRequest, error) GetNextHrid(ctx common.ExtendedContext, prefix string) (string, error) @@ -69,8 +70,12 @@ func (r *PgPrRepo) ListPatronRequests(ctx common.ExtendedContext, params ListPat return list, fullCount, err } -func (r *PgPrRepo) SavePatronRequest(ctx common.ExtendedContext, params SavePatronRequestParams) (PatronRequest, error) { - row, err := r.queries.SavePatronRequest(ctx, r.GetConnOrTx(), params) +func (r *PgPrRepo) UpdatePatronRequest(ctx common.ExtendedContext, params UpdatePatronRequestParams) (PatronRequest, error) { + row, err := r.queries.UpdatePatronRequest(ctx, r.GetConnOrTx(), params) + return row.PatronRequest, err +} +func (r *PgPrRepo) CreatePatronRequest(ctx common.ExtendedContext, params CreatePatronRequestParams) (PatronRequest, error) { + row, err := r.queries.CreatePatronRequest(ctx, r.GetConnOrTx(), params) return row.PatronRequest, err } diff --git a/broker/patron_request/service/action.go b/broker/patron_request/service/action.go index 03e0edfe..7d911099 100644 --- a/broker/patron_request/service/action.go +++ b/broker/patron_request/service/action.go @@ -212,7 +212,7 @@ func (a *PatronRequestActionService) handleLenderAction(ctx common.ExtendedConte func (a *PatronRequestActionService) updateStateAndReturnResult(ctx common.ExtendedContext, pr pr_db.PatronRequest, state pr_db.PatronRequestState, result *events.EventResult) (events.EventStatus, *events.EventResult) { pr.State = state - pr, err := a.prRepo.SavePatronRequest(ctx, pr_db.SavePatronRequestParams(pr)) + pr, err := a.prRepo.UpdatePatronRequest(ctx, pr_db.UpdatePatronRequestParams(pr)) if err != nil { return events.LogErrorAndReturnResult(ctx, "failed to update patron request", err) } diff --git a/broker/patron_request/service/action_test.go b/broker/patron_request/service/action_test.go index cc8596fd..bfdc1c57 100644 --- a/broker/patron_request/service/action_test.go +++ b/broker/patron_request/service/action_test.go @@ -677,7 +677,15 @@ func (r *MockPrRepo) GetPatronRequestById(ctx common.ExtendedContext, id string) return args.Get(0).(pr_db.PatronRequest), args.Error(1) } -func (r *MockPrRepo) SavePatronRequest(ctx common.ExtendedContext, params pr_db.SavePatronRequestParams) (pr_db.PatronRequest, error) { +func (r *MockPrRepo) UpdatePatronRequest(ctx common.ExtendedContext, params pr_db.UpdatePatronRequestParams) (pr_db.PatronRequest, error) { + if strings.Contains(params.ID, "error") || strings.Contains(params.RequesterReqID.String, "error") { + return pr_db.PatronRequest{}, errors.New("db error") + } + r.savedPr = pr_db.PatronRequest(params) + return pr_db.PatronRequest(params), nil +} + +func (r *MockPrRepo) CreatePatronRequest(ctx common.ExtendedContext, params pr_db.CreatePatronRequestParams) (pr_db.PatronRequest, error) { if strings.Contains(params.ID, "error") || strings.Contains(params.RequesterReqID.String, "error") { return pr_db.PatronRequest{}, errors.New("db error") } diff --git a/broker/patron_request/service/message-handler.go b/broker/patron_request/service/message-handler.go index 5fc20c2e..b1c96d16 100644 --- a/broker/patron_request/service/message-handler.go +++ b/broker/patron_request/service/message-handler.go @@ -134,7 +134,7 @@ func (m *PatronRequestMessageHandler) handleSupplyingAgencyMessage(ctx common.Ex } func (m *PatronRequestMessageHandler) updatePatronRequestAndCreateSamResponse(ctx common.ExtendedContext, pr pr_db.PatronRequest, sam iso18626.SupplyingAgencyMessage) (events.EventStatus, *iso18626.ISO18626Message, error) { - _, err := m.prRepo.SavePatronRequest(ctx, pr_db.SavePatronRequestParams(pr)) + _, err := m.prRepo.UpdatePatronRequest(ctx, pr_db.UpdatePatronRequestParams(pr)) if err != nil { return createSAMResponse(sam, iso18626.TypeMessageStatusERROR, &iso18626.ErrorData{ ErrorType: iso18626.TypeErrorTypeUnrecognisedDataValue, @@ -214,7 +214,7 @@ func (m *PatronRequestMessageHandler) handleRequestMessage(ctx common.ExtendedCo ErrorValue: err.Error(), }, err) } - pr, err := m.prRepo.SavePatronRequest(ctx, pr_db.SavePatronRequestParams{ + pr, err := m.prRepo.CreatePatronRequest(ctx, pr_db.CreatePatronRequestParams{ ID: uuid.NewString(), Timestamp: pgtype.Timestamp{Valid: true, Time: time.Now()}, State: LenderStateNew, @@ -293,7 +293,7 @@ func createRAMResponse(ram iso18626.RequestingAgencyMessage, messageStatus iso18 } func (m *PatronRequestMessageHandler) updatePatronRequestAndCreateRamResponse(ctx common.ExtendedContext, pr pr_db.PatronRequest, ram iso18626.RequestingAgencyMessage, action *iso18626.TypeAction) (events.EventStatus, *iso18626.ISO18626Message, error) { - _, err := m.prRepo.SavePatronRequest(ctx, pr_db.SavePatronRequestParams(pr)) + _, err := m.prRepo.UpdatePatronRequest(ctx, pr_db.UpdatePatronRequestParams(pr)) if err != nil { return createRAMResponse(ram, iso18626.TypeMessageStatusERROR, action, &iso18626.ErrorData{ ErrorType: iso18626.TypeErrorTypeUnrecognisedDataValue, diff --git a/broker/sqlc/pr_query.sql b/broker/sqlc/pr_query.sql index bc50a2dc..5b9322bb 100644 --- a/broker/sqlc/pr_query.sql +++ b/broker/sqlc/pr_query.sql @@ -10,19 +10,23 @@ FROM patron_request ORDER BY timestamp LIMIT $1 OFFSET $2; --- name: SavePatronRequest :one +-- name: UpdatePatronRequest :one +UPDATE patron_request +SET timestamp = $2, + ill_request = $3, + state = $4, + side = $5, + patron = $6, + requester_symbol = $7, + supplier_symbol = $8, + tenant = $9, + requester_req_id = $10 +WHERE id = $1 +RETURNING sqlc.embed(patron_request); + +-- name: CreatePatronRequest :one INSERT INTO patron_request (id, timestamp, ill_request, state, side, patron, requester_symbol, supplier_symbol, tenant, requester_req_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) -ON CONFLICT (id) DO UPDATE - SET timestamp = EXCLUDED.timestamp, - ill_request = EXCLUDED.ill_request, - state = EXCLUDED.state, - side = EXCLUDED.side, - patron = EXCLUDED.patron, - requester_symbol = EXCLUDED.requester_symbol, - supplier_symbol = EXCLUDED.supplier_symbol, - tenant = EXCLUDED.tenant, - requester_req_id = EXCLUDED.requester_req_id RETURNING sqlc.embed(patron_request); -- name: DeletePatronRequest :exec diff --git a/broker/test/patron_request/api/api-handler_test.go b/broker/test/patron_request/api/api-handler_test.go index b164a1e0..949b0031 100644 --- a/broker/test/patron_request/api/api-handler_test.go +++ b/broker/test/patron_request/api/api-handler_test.go @@ -112,7 +112,11 @@ func TestCrud(t *testing.T) { newPrBytes, err := json.Marshal(newPr) assert.NoError(t, err, "failed to marshal patron request") - respBytes := httpRequest(t, "POST", basePath, newPrBytes, 201) + hres, respBytes := httpRequest2(t, "POST", basePath, newPrBytes, 201) + // Check Location header + location := hres.Header.Get("Location") + assert.NotEmpty(t, location, "Location header should be set") + assert.Equal(t, getLocalhostWithPort()+"/patron_requests/"+id, location) var foundPr proapi.PatronRequest err = json.Unmarshal(respBytes, &foundPr) @@ -125,6 +129,9 @@ func TestCrud(t *testing.T) { assert.Equal(t, *newPr.SupplierSymbol, *foundPr.SupplierSymbol) assert.Equal(t, *newPr.Patron, *foundPr.Patron) + respBytes = httpRequest(t, "POST", basePath, newPrBytes, 400) + assert.Contains(t, string(respBytes), "a patron request with this ID already exists") + // GET list queryParams := "?side=borrowing&symbol=" + *foundPr.RequesterSymbol respBytes = httpRequest(t, "GET", basePath+queryParams, []byte{}, 200) @@ -408,7 +415,7 @@ func TestGetReturnableStateModel(t *testing.T) { assert.Equal(t, len(*returnablesStateModel.States), len(*retrievedStateModel.States)) } -func httpRequest(t *testing.T, method string, uriPath string, reqbytes []byte, expectStatus int) []byte { +func httpRequest2(t *testing.T, method string, uriPath string, reqbytes []byte, expectStatus int) (*http.Response, []byte) { client := http.DefaultClient hreq, err := http.NewRequest(method, getLocalhostWithPort()+uriPath, bytes.NewBuffer(reqbytes)) assert.NoError(t, err) @@ -423,7 +430,12 @@ func httpRequest(t *testing.T, method string, uriPath string, reqbytes []byte, e body, err := io.ReadAll(hres.Body) assert.Equal(t, expectStatus, hres.StatusCode, string(body)) assert.NoError(t, err) - return body + return hres, body +} + +func httpRequest(t *testing.T, method string, uriPath string, reqbytes []byte, expectStatus int) []byte { + _, respBytes := httpRequest2(t, method, uriPath, reqbytes, expectStatus) + return respBytes } func getLocalhostWithPort() string { diff --git a/broker/test/patron_request/db/prrepo_test.go b/broker/test/patron_request/db/prrepo_test.go index fc005162..1d1c4285 100644 --- a/broker/test/patron_request/db/prrepo_test.go +++ b/broker/test/patron_request/db/prrepo_test.go @@ -68,7 +68,7 @@ func TestMain(m *testing.M) { func TestItem(t *testing.T) { prId := uuid.NewString() - _, err := prRepo.SavePatronRequest(appCtx, pr_db.SavePatronRequestParams{ + _, err := prRepo.CreatePatronRequest(appCtx, pr_db.CreatePatronRequestParams{ ID: prId, Timestamp: pgtype.Timestamp{ Time: time.Now(), @@ -156,7 +156,7 @@ func TestItem(t *testing.T) { func TestNotification(t *testing.T) { prId := uuid.NewString() - _, err := prRepo.SavePatronRequest(appCtx, pr_db.SavePatronRequestParams{ + _, err := prRepo.CreatePatronRequest(appCtx, pr_db.CreatePatronRequestParams{ ID: prId, Timestamp: pgtype.Timestamp{ Time: time.Now(),