diff --git a/.github/workflows/binary-gems.yml b/.github/workflows/binary-gems.yml index 175394262..870182159 100644 --- a/.github/workflows/binary-gems.yml +++ b/.github/workflows/binary-gems.yml @@ -111,6 +111,11 @@ jobs: brew: "postgresql" # macOS mingw: "postgresql" # Windows mingw / mswin /ucrt + - name: Install postgresql server headers Ubuntu + if: startsWith(matrix.os, 'ubuntu-') + run: | + sudo apt-get -y --allow-downgrades install '^postgresql-server-dev-[0-9]+$' libkrb5-dev + - name: Set up 32 bit x86 Ruby if: matrix.platform == 'x86-mingw32' run: | diff --git a/.github/workflows/source-gem.yml b/.github/workflows/source-gem.yml index 7c1c3deae..4e5eafb91 100644 --- a/.github/workflows/source-gem.yml +++ b/.github/workflows/source-gem.yml @@ -127,9 +127,14 @@ jobs: echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVER" | sudo tee -a /etc/apt/sources.list.d/pgdg.list wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - sudo apt-get -y update - sudo apt-get -y --allow-downgrades install postgresql-$PGVER libpq5=$PGVER* libpq-dev=$PGVER* + sudo apt-get -y --allow-downgrades install postgresql-$PGVER libpq5=$PGVER* libpq-dev=$PGVER* postgresql-server-dev-$PGVER libkrb5-dev echo /usr/lib/postgresql/$PGVER/bin >> $GITHUB_PATH + - name: Download OAuth support Ubuntu + if: matrix.os == 'ubuntu' && matrix.PGVER >= 18 + run: | + sudo apt-get -y --allow-downgrades install libpq-oauth=$PGVER* + - name: Download PostgreSQL Macos if: matrix.os == 'macos' run: | diff --git a/.gitignore b/.gitignore index 1de9c8ee9..b82fc544c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,9 @@ /lib/2.?/ /lib/3.?/ /pkg/ +/spec/oauth/*.bc +/spec/oauth/*.o +/spec/oauth/*.so /tmp/ /tmp_test_*/ /vendor/ diff --git a/Gemfile b/Gemfile index c22c0988b..8cd0309ad 100644 --- a/Gemfile +++ b/Gemfile @@ -15,6 +15,7 @@ group :test do gem "rake-compiler", "~> 1.0" gem "rake-compiler-dock", "~> 1.11.0" #, git: "https://github.com/rake-compiler/rake-compiler-dock" gem "rspec", "~> 3.5" + gem "webrick", "~> 1.8" # "bigdecimal" is a gem on ruby-3.4+ and it's optional for ruby-pg. # Specs should succeed without it, but 4 examples are then excluded. # With bigdecimal commented out here, corresponding tests are omitted on ruby-3.4+ but are executed on ruby < 3.4. diff --git a/ext/extconf.rb b/ext/extconf.rb index 4e293b080..b406059a7 100644 --- a/ext/extconf.rb +++ b/ext/extconf.rb @@ -316,6 +316,7 @@ module PG src + " int con(){ return PGRES_PIPELINE_SYNC; }" end have_func 'PQsetChunkedRowsMode', 'libpq-fe.h' # since PostgreSQL-17 +have_func 'PQsetAuthDataHook', 'libpq-fe.h' # since PostgreSQL-18 have_func 'timegm' have_func 'rb_io_wait' # since ruby-3.0 have_func 'rb_io_descriptor' # since ruby-3.1 diff --git a/ext/gvl_wrappers.h b/ext/gvl_wrappers.h index b3526d184..cc32663ce 100644 --- a/ext/gvl_wrappers.h +++ b/ext/gvl_wrappers.h @@ -281,17 +281,34 @@ FOR_EACH_BLOCKING_FUNCTION( DEFINE_GVL_STUB_DECL ); * Definitions of callback functions and their parameters */ +#define FOR_EACH_PARAM_OF_auth_data_hook_proxy(param) \ + param(PGauthData, type) \ + param(PGconn *, conn) + #define FOR_EACH_PARAM_OF_notice_processor_proxy(param) \ param(void *, arg) #define FOR_EACH_PARAM_OF_notice_receiver_proxy(param) \ param(void *, arg) +#ifdef HAVE_PQSETAUTHDATAHOOK + /* function( name, void_or_nonvoid, returntype, lastparamtype, lastparamname ) */ #define FOR_EACH_CALLBACK_FUNCTION(function) \ + function(auth_data_hook_proxy, GVL_TYPE_NONVOID, int, void *, data) \ function(notice_processor_proxy, GVL_TYPE_VOID, void, const char *, message) \ function(notice_receiver_proxy, GVL_TYPE_VOID, void, const PGresult *, result) \ +#else + +/* function( name, void_or_nonvoid, returntype, lastparamtype, lastparamname ) */ +#define FOR_EACH_CALLBACK_FUNCTION(function) \ + function(notice_processor_proxy, GVL_TYPE_VOID, void, const char *, message) \ + function(notice_receiver_proxy, GVL_TYPE_VOID, void, const PGresult *, result) \ + +#endif + FOR_EACH_CALLBACK_FUNCTION( DEFINE_GVL_STUB_DECL ); + #endif /* end __gvl_wrappers_h */ diff --git a/ext/pg.c b/ext/pg.c index db6d5cbfa..229a477df 100644 --- a/ext/pg.c +++ b/ext/pg.c @@ -682,6 +682,7 @@ Init_pg_ext(void) /* Initialize the main extension classes */ init_pg_connection(); + init_pg_auth_hooks(); init_pg_result(); init_pg_errors(); init_pg_type_map(); diff --git a/ext/pg.h b/ext/pg.h index 290333db5..5b8f98ce1 100644 --- a/ext/pg.h +++ b/ext/pg.h @@ -289,6 +289,7 @@ extern VALUE pg_typemap_all_strings; void Init_pg_ext _(( void )); void init_pg_connection _(( void )); +void init_pg_auth_hooks _(( void )); void init_pg_result _(( void )); void init_pg_errors _(( void )); void init_pg_type_map _(( void )); @@ -375,6 +376,9 @@ rb_encoding * pg_get_pg_encname_as_rb_encoding _(( const char * )); const char * pg_get_rb_encoding_as_pg_encoding _(( rb_encoding * )); rb_encoding *pg_conn_enc_get _(( PGconn * )); +#ifdef HAVE_PQSETAUTHDATAHOOK +int auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data); +#endif void notice_receiver_proxy(void *arg, const PGresult *result); void notice_processor_proxy(void *arg, const char *message); diff --git a/ext/pg_auth_hooks.c b/ext/pg_auth_hooks.c new file mode 100644 index 000000000..23cdd03e4 --- /dev/null +++ b/ext/pg_auth_hooks.c @@ -0,0 +1,388 @@ +/* + * pg_auth_hooks.c - Auth hooks for PG module + * $Id$ + * + */ + +#include "pg.h" + +#if HAVE_PQSETAUTHDATAHOOK + +#ifdef TRUFFLERUBY +static VALUE auth_data_hook; +#else +/* + * On Ruby verisons which support Ractors we store the global callback once + * per Ractor. + */ +#include "ruby/ractor.h" +static rb_ractor_local_key_t auth_data_hook_key; +#endif + +static void +auth_data_hook_init(void) +{ +#ifdef TRUFFLERUBY + auth_data_hook = Qnil; + rb_gc_register_address(&auth_data_hook); +#else + auth_data_hook_key = rb_ractor_local_storage_value_newkey(); +#endif +} + +static VALUE +auth_data_hook_get(void) +{ +#ifdef TRUFFLERUBY + return auth_data_hook; +#else + VALUE hook = Qnil; + rb_ractor_local_storage_value_lookup(auth_data_hook_key, &hook); + return hook; +#endif +} + +static void +auth_data_hook_set(VALUE hook) +{ +#ifdef TRUFFLERUBY + auth_data_hook = hook; +#else + rb_ractor_local_storage_value_set(auth_data_hook_key, hook); +#endif +} + +static VALUE rb_cPromptOAuthDevice; +static VALUE rb_cOAuthBearerRequest; + +/* + * Document-class: PG::PromptOAuthDevice + */ + +typedef struct { + PGpromptOAuthDevice *prompt; +} t_pg_prompt_oauth_device; + +static size_t +pg_prompt_oauth_device_memsize(const void *_this) +{ + return sizeof(t_pg_prompt_oauth_device); +} + +static const rb_data_type_t pg_prompt_oauth_device_type = { + "PG::PromptOAuthDevice", + { + NULL, + RUBY_TYPED_DEFAULT_FREE, + pg_prompt_oauth_device_memsize, + NULL, + }, + 0, + 0, + RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_FREE_IMMEDIATELY, +}; + +static t_pg_prompt_oauth_device * +pg_get_prompt_oauth_device_safe(VALUE self) +{ + t_pg_prompt_oauth_device *this; + + TypedData_Get_Struct(self, t_pg_prompt_oauth_device, &pg_prompt_oauth_device_type, this); + + if (!this->prompt) + rb_raise(rb_ePGerror, "data cannot be accessed after callback has completed"); + + return this; +} + +/* + * call-seq: + * prompt.verification_uri -> String + */ +static VALUE +pg_prompt_oauth_device_verification_uri(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + if (!this->prompt->verification_uri) + rb_raise(rb_ePGerror, "internal error: verification_uri is missing"); + + return rb_str_new_cstr(this->prompt->verification_uri); +} + +/* + * call-seq: + * prompt.user_code -> String + */ +static VALUE +pg_prompt_oauth_device_user_code(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + if (!this->prompt->user_code) + rb_raise(rb_ePGerror, "internal error: user_code is missing"); + + return rb_str_new_cstr(this->prompt->user_code); +} + +/* + * call-seq: + * prompt.verification_uri_complete -> String | nil + */ +static VALUE +pg_prompt_oauth_device_verification_uri_complete(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + return this->prompt->verification_uri_complete ? rb_str_new_cstr(this->prompt->verification_uri_complete) : Qnil; +} + +/* + * call-seq: + * prompt.expires_in -> Integer + */ +static VALUE +pg_prompt_oauth_device_expires_in(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + return INT2FIX(this->prompt->expires_in); +} + +/* + * Document-class: PG::OAuthBearerRequest + */ + +typedef struct { + PGoauthBearerRequest *request; +} t_pg_oauth_bearer_request; + +static size_t +pg_oauth_bearer_request_memsize(const void *_this) +{ + return sizeof(t_pg_oauth_bearer_request); +} + +static const rb_data_type_t pg_oauth_bearer_request_type = { + "PG::OAuthBearerRequest", + { + NULL, + RUBY_TYPED_DEFAULT_FREE, + pg_oauth_bearer_request_memsize, + NULL, + }, + 0, + 0, + RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_FREE_IMMEDIATELY, +}; + +static t_pg_oauth_bearer_request * +pg_get_oauth_bearer_request_safe(VALUE self) +{ + t_pg_oauth_bearer_request *this; + + TypedData_Get_Struct(self, t_pg_oauth_bearer_request, &pg_oauth_bearer_request_type, this); + + if (!this->request) + rb_raise(rb_ePGerror, "data cannot be accessed after callback has completed"); + + return this; +} + +/* + * call-seq: + * prompt.openid_configuration -> String + */ +static VALUE +pg_oauth_bearer_request_openid_configuration(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + if (!this->request->openid_configuration) + rb_raise(rb_ePGerror, "internal error: openid_configuration is missing"); + + return rb_str_new_cstr(this->request->openid_configuration); +} + +/* + * call-seq: + * request.scope -> String | nil + */ +static VALUE +pg_oauth_bearer_request_scope(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + return this->request->scope ? rb_str_new_cstr(this->request->scope) : Qnil; +} + +/* + * call-seq: + * request.token = token + * + * See also #token + */ +static VALUE +pg_oauth_bearer_request_token_set(VALUE self, VALUE token) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + /* This can throw an exception so needs to be done before free() */ + char *token_cstr = NIL_P(token) ? NULL : strdup(StringValueCStr(token)); + + if (this->request->token) + free(this->request->token); + + this->request->token = token_cstr; + + return token; +} + +/* + * call-seq: + * request.token -> String | nil + * + * See also #token= + */ +static VALUE +pg_oauth_bearer_request_token_get(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + return this->request->token ? rb_str_new_cstr(this->request->token) : Qnil; +} + +static void +oauth_bearer_request_cleanup(PGconn *_conn, struct PGoauthBearerRequest *request) +{ + if (request->token) + free(request->token); +} + +static VALUE +call_auth_data_hook(VALUE data) +{ + VALUE proc = auth_data_hook_get(); + + return rb_funcall(proc, rb_intern("call"), 1, data); +} + +static VALUE +prompt_oauth_device_hook_cleanup(VALUE self, VALUE ex) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + this->prompt = NULL; + + rb_exc_raise(ex); +} + +static VALUE +oauth_bearer_request_hook_cleanup(VALUE self, VALUE ex) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + if (this->request->token) + free(this->request->token); + this->request->token = NULL; + + this->request = NULL; + + rb_exc_raise(ex); +} + +/* + * Auth data proxy function -- delegate the callback to the + * currently-registered Ruby auth_data_hook object. + */ +int +auth_data_hook_proxy(PGauthData type, PGconn *_conn, void *data) +{ + VALUE proc = auth_data_hook_get(), ret = Qnil; + + if (proc != Qnil) { + if (type == PQAUTHDATA_PROMPT_OAUTH_DEVICE) { + t_pg_prompt_oauth_device *prompt; + + VALUE v_prompt = TypedData_Make_Struct(rb_cPromptOAuthDevice, t_pg_prompt_oauth_device, &pg_prompt_oauth_device_type, prompt); + + prompt->prompt = data; + + ret = rb_rescue(call_auth_data_hook, v_prompt, prompt_oauth_device_hook_cleanup, v_prompt); + + prompt->prompt = NULL; + } else if (type == PQAUTHDATA_OAUTH_BEARER_TOKEN) { + t_pg_oauth_bearer_request *request; + + VALUE v_request = TypedData_Make_Struct(rb_cOAuthBearerRequest, t_pg_oauth_bearer_request, &pg_oauth_bearer_request_type, request); + + request->request = data; + request->request->cleanup = oauth_bearer_request_cleanup; + + ret = rb_rescue(call_auth_data_hook, v_request, oauth_bearer_request_hook_cleanup, v_request); + + request->request = NULL; + } + } + + return RTEST(ret); +} + +/* + * Document-method: PG.set_auth_data_hook + * + * call-seq: + * PG.set_auth_data_hook {|data| ... } -> Proc + * + * If you pass no arguments, it will reset the handler to the default. + */ +static VALUE +pg_s_set_auth_data_hook(VALUE _self) +{ + PQsetAuthDataHook(gvl_auth_data_hook_proxy); // TODO: Add some safeguards? + + VALUE old_proc = auth_data_hook_get(), proc; + + if (rb_block_given_p()) { + proc = rb_block_proc(); + } else { + /* if no block is given, set back to default */ + proc = Qnil; + } + + auth_data_hook_set(proc); + + return old_proc; +} + +void +init_pg_auth_hooks(void) +{ + auth_data_hook_init(); + + /* rb_mPG = rb_define_module("PG") */ + + rb_define_singleton_method(rb_mPG, "set_auth_data_hook", pg_s_set_auth_data_hook, 0); + + rb_cPromptOAuthDevice = rb_define_class_under(rb_mPG, "PromptOAuthDevice", rb_cObject); + rb_undef_alloc_func(rb_cPromptOAuthDevice); + + rb_define_method(rb_cPromptOAuthDevice, "verification_uri", pg_prompt_oauth_device_verification_uri, 0); + rb_define_method(rb_cPromptOAuthDevice, "user_code", pg_prompt_oauth_device_user_code, 0); + rb_define_method(rb_cPromptOAuthDevice, "verification_uri_complete", pg_prompt_oauth_device_verification_uri_complete, 0); + rb_define_method(rb_cPromptOAuthDevice, "expires_in", pg_prompt_oauth_device_expires_in, 0); + + rb_cOAuthBearerRequest = rb_define_class_under(rb_mPG, "OAuthBearerRequest", rb_cObject); + rb_undef_alloc_func(rb_cOAuthBearerRequest); + + rb_define_method(rb_cOAuthBearerRequest, "openid_configuration", pg_oauth_bearer_request_openid_configuration, 0); + rb_define_method(rb_cOAuthBearerRequest, "scope", pg_oauth_bearer_request_scope, 0); + rb_define_method(rb_cOAuthBearerRequest, "token=", pg_oauth_bearer_request_token_set, 1); + rb_define_method(rb_cOAuthBearerRequest, "token", pg_oauth_bearer_request_token_get, 0); +} + +#else + +void init_pg_auth_hooks(void) {} + +#endif diff --git a/ext/pg_connection.c b/ext/pg_connection.c index 696f26f6f..981678330 100644 --- a/ext/pg_connection.c +++ b/ext/pg_connection.c @@ -940,6 +940,19 @@ pgconn_socket(VALUE self) return INT2NUM(sd); } +#ifdef _WIN32 +#define is_socket(fd) rb_w32_is_socket(fd) +#else +static int +is_socket(int fd) +{ + struct stat sbuf; + + if (fstat(fd, &sbuf) < 0) + rb_sys_fail("fstat(2)"); + return S_ISSOCK(sbuf.st_mode); +} +#endif VALUE pg_wrap_socket_io(int sd, VALUE self, VALUE *p_socket_io, int *p_ruby_sd) @@ -958,7 +971,7 @@ pg_wrap_socket_io(int sd, VALUE self, VALUE *p_socket_io, int *p_ruby_sd) *p_ruby_sd = ruby_sd = sd; #endif - cSocket = rb_const_get(rb_cObject, rb_intern("BasicSocket")); + cSocket = rb_const_get(rb_cObject, rb_intern(is_socket(ruby_sd) ? "BasicSocket" : "IO")); socket_io = rb_funcall( cSocket, rb_intern("for_fd"), 1, INT2NUM(ruby_sd)); /* Disable autoclose feature */ diff --git a/spec/helpers.rb b/spec/helpers.rb index 5935e2034..bbcc1faf8 100644 --- a/spec/helpers.rb +++ b/spec/helpers.rb @@ -194,6 +194,7 @@ class PostgresServer attr_reader :port attr_reader :conninfo attr_reader :unix_socket + attr_reader :version ### Set up a PostgreSQL database instance for testing. def initialize(name, port: 23456, postgresql_conf: '') @@ -205,6 +206,7 @@ def initialize(name, port: 23456, postgresql_conf: '') @pgdata = @test_dir + 'data' @logfile = @test_dir + 'setup.log' @pg_bindir = pg_bindir + @version = pg_version @unix_socket = @test_dir.to_s @conninfo = "host=localhost port=#{@port} dbname=test sslrootcert=#{@pgdata + 'ruby-pg-ca-cert'} sslcert=#{@pgdata + 'ruby-pg-client-cert'} sslkey=#{@pgdata + 'ruby-pg-client-key'}" @@ -267,8 +269,13 @@ def setup_cluster(postgresql_conf) ssl_cert_file = 'ruby-pg-server-cert' ssl_key_file = 'ruby-pg-server-key' fsync = off - #{postgresql_conf} EOT + if @version >= 18 + fd.puts <<~EOT + oauth_validator_libraries = '#{TEST_DIRECTORY}/spec/oauth/dummy_validator' + EOT + end + fd.puts postgresql_conf end # Enable MD5 authentication in hba config @@ -278,6 +285,12 @@ def setup_cluster(postgresql_conf) # TYPE DATABASE USER ADDRESS METHOD host all testusermd5 ::1/128 md5 EOT + if @version >= 18 + fd.puts <<~EOT + host all testuseroauth 127.0.0.1/32 oauth scope=test issuer="http://localhost:#{@port + 3}" + host all testuseroauth ::1/32 oauth scope=test issuer="http://localhost:#{@port + 3}" + EOT + end fd.puts hba_content end @@ -340,6 +353,10 @@ def pg_bindir rescue nil end + + def pg_version + `#{pg_bin_path("pg_ctl")} --version`[/pg_ctl \(PostgreSQL\) (\d+)/, 1]&.to_i + end end class CertGenerator @@ -682,6 +699,7 @@ def set_etc_hosts(hostaddr, hostname) config.filter_run_excluding( :postgresql_12 ) if PG.library_version < 120000 config.filter_run_excluding( :postgresql_14 ) if PG.library_version < 140000 config.filter_run_excluding( :postgresql_17 ) if PG.library_version < 170000 + config.filter_run_excluding( :postgresql_18 ) if PG.library_version < 180000 config.filter_run_excluding( :unix_socket ) if RUBY_PLATFORM=~/mingw|mswin/i config.filter_run_excluding( :scheduler ) if RUBY_VERSION < "3.0" || (RUBY_PLATFORM =~ /mingw|mswin/ && RUBY_VERSION < "3.1") || !Fiber.respond_to?(:scheduler) config.filter_run_excluding( :scheduler_address_resolve ) if RUBY_VERSION < "3.1" diff --git a/spec/oauth/Makefile b/spec/oauth/Makefile new file mode 100644 index 000000000..508292cea --- /dev/null +++ b/spec/oauth/Makefile @@ -0,0 +1,8 @@ +MODULES = dummy_validator +PGFILEDESC = "dummy_validator - dummy OAuth validator" + +OBJS = $(WIN32RES) + +PG_CONFIG = pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) diff --git a/spec/oauth/dummy_validator.c b/spec/oauth/dummy_validator.c new file mode 100644 index 000000000..33018392a --- /dev/null +++ b/spec/oauth/dummy_validator.c @@ -0,0 +1,29 @@ +#include "postgres.h" +#include "fmgr.h" +#include "libpq/oauth.h" + +PG_MODULE_MAGIC; + +static bool +validate_token(const ValidatorModuleState *state, + const char *token, const char *role, + ValidatorModuleResult *res) +{ + if (strcmp(token, "yes") == 0) + { + res->authorized = true; + res->authn_id = pstrdup(role); + } + return true; +} + +static const OAuthValidatorCallbacks validator_callbacks = { + PG_OAUTH_VALIDATOR_MAGIC, + .validate_cb = validate_token +}; + +const OAuthValidatorCallbacks * +_PG_oauth_validator_module_init(void) +{ + return &validator_callbacks; +} diff --git a/spec/pg/connection_spec.rb b/spec/pg/connection_spec.rb index 9ac2a2475..ca038bdd1 100644 --- a/spec/pg/connection_spec.rb +++ b/spec/pg/connection_spec.rb @@ -2998,4 +2998,147 @@ def wait_check_socket(conn) .to raise_error(TypeError) end end + + describe "PG.set_auth_data_hook", :postgresql_18 do + before :all do + skip "requires a PostgreSQL 18 cluster" unless $pg_server.version >= 18 + + system "make", "-s", "-C", (TEST_DIRECTORY + "spec/oauth").to_s + raise "Building OAuth validator library failed!" unless $?.success? + + require 'webrick' + + PG.connect(@conninfo) do |conn| + conn.exec("DROP USER IF EXISTS testuseroauth") + conn.exec("CREATE USER testuseroauth") + end + end + + before :each do + @old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE" + end + + def start_fake_oauth(port) + server = WEBrick::HTTPServer.new(Port: port, Logger: WEBrick::Log.new(nil, WEBrick::BasicLog::WARN)) + server.mount_proc("/.well-known/openid-configuration") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"issuer":"http://localhost:#{port}","token_endpoint":"http://localhost:#{port}/token","device_authorization_endpoint":"http://localhost:#{@port + 3}/devauth"}! + end + server.mount_proc("/devauth") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"device_code":"42","user_code":"666","verification_uri":"http://localhost:#{port}/verify","expires_in":60}! + end + server.mount_proc("/token") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"access_token":"yes","token_type":""}! + end + Thread.new { server.start } + server + end + + it "should work with no hook" do + oauth_server = start_fake_oauth(@port + 3) + + begin + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + conn.exec("SELECT 1") + end + rescue PG::ConnectionBad => e + if e.message =~ /no OAuth flows are available/ + skip "requires libpq-oauth to be installed" + end + raise + ensure + oauth_server.shutdown + end + end + + it "should call prompt oauth device hook" do + oauth_server = start_fake_oauth(@port + 3) + + verification_uri, user_code, verification_uri_complete, expires_in = nil, nil, nil, nil + + PG.set_auth_data_hook do |data| + case data + when PG::PromptOAuthDevice + verification_uri = data.verification_uri + user_code = data.user_code + verification_uri_complete = data.verification_uri_complete + expires_in = data.expires_in + true + end + end + + begin + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + conn.exec("SELECT 1") + end + rescue PG::ConnectionBad => e + if e.message =~ /no OAuth flows are available/ + skip "requires libpq-oauth to be installed" + end + raise + ensure + oauth_server.shutdown + end + + expect(verification_uri).to eq("http://localhost:#{@port + 3}/verify") + expect(user_code).to eq("666") + expect(verification_uri_complete).to eq(nil) + expect(expires_in).to eq(60) + end + + it "should call oauth bearer reuqest hook" do + openid_configuration, scope = nil, nil + + PG.set_auth_data_hook do |data| + case data + when PG::OAuthBearerRequest + openid_configuration = data.openid_configuration + scope = data.scope + data.token = "yes" + true + end + end + + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + conn.exec("SELECT 1") + end + + expect(openid_configuration).to eq("http://localhost:#{@port + 3}/.well-known/openid-configuration") + expect(scope).to eq("test") + end + + it "should reset the hook when called without block" do + oauth_server = start_fake_oauth(@port + 3) + + PG.set_auth_data_hook do |data| + raise "broken hook" + end + + expect do + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") {} + end.to raise_error("broken hook") + + PG.set_auth_data_hook + + begin + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + conn.exec("SELECT 1") + end + rescue PG::ConnectionBad => e + if e.message =~ /no OAuth flows are available/ + skip "requires libpq-oauth to be installed" + end + raise + ensure + oauth_server.shutdown + end + end + + after :each do + PG.set_auth_data_hook + ENV["PGOAUTHDEBUG"] = @old_env + end + end end