diff --git a/internal/webserver/handler_lists.go b/internal/webserver/handler_lists.go index 93d4f3a..184b4c9 100644 --- a/internal/webserver/handler_lists.go +++ b/internal/webserver/handler_lists.go @@ -130,9 +130,11 @@ func (app *Application) Lists(w http.ResponseWriter, r *http.Request) { app.error_400_with_message(w, "List ID must be a number") return } - // XXX: Check that the list exists - // Need to modify signature to return an error, because it might be ErrNoRows - list := app.Profile.GetListById(ListID(_list_id)) + list, err := app.Profile.GetListById(ListID(_list_id)) + if err != nil { + app.error_404(w) + return + } req_with_ctx := r.WithContext(add_list_to_context(r.Context(), list)) http.StripPrefix(fmt.Sprintf("/%d", list.ID), http.HandlerFunc(app.ListDetail)).ServeHTTP(w, req_with_ctx) return diff --git a/internal/webserver/server_test.go b/internal/webserver/server_test.go index 4a267ba..965987a 100644 --- a/internal/webserver/server_test.go +++ b/internal/webserver/server_test.go @@ -597,8 +597,8 @@ func TestListsIndex(t *testing.T) { root, err := html.Parse(resp.Body) require.NoError(err) - // Check that there's 2 Lists - assert.Len(t, cascadia.QueryAll(root, selector(".users-list-preview")), 2) + // Check that there's at least 2 Lists + assert.True(t, len(cascadia.QueryAll(root, selector(".users-list-preview"))) >= 2) } func TestListDetail(t *testing.T) { @@ -620,6 +620,11 @@ func TestListDetail(t *testing.T) { assert.Len(cascadia.QueryAll(root1, selector(".timeline > .tweet")), 3) } +func TestListDetailDoesntExist(t *testing.T) { + resp := do_request(httptest.NewRequest("GET", "/lists/2523478", nil)) + require.Equal(t, resp.StatusCode, 404) +} + func TestListDetailInvalidId(t *testing.T) { resp := do_request(httptest.NewRequest("GET", "/lists/asd", nil)) require.Equal(t, resp.StatusCode, 400) diff --git a/pkg/persistence/list_queries.go b/pkg/persistence/list_queries.go index d3589a2..8c2cb5b 100644 --- a/pkg/persistence/list_queries.go +++ b/pkg/persistence/list_queries.go @@ -73,13 +73,15 @@ func (p Profile) DeleteListUser(list_id ListID, user_id UserID) { } } -func (p Profile) GetListById(list_id ListID) List { +func (p Profile) GetListById(list_id ListID) (List, error) { var ret List err := p.DB.Get(&ret, `select rowid, is_online, online_list_id, name from lists where rowid = ?`, list_id) - if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return List{}, ErrNotInDatabase{"List", list_id} + } else if err != nil { panic(err) } - return ret + return ret, nil } func (p Profile) GetListUsers(list_id ListID) []User { diff --git a/pkg/persistence/list_queries_test.go b/pkg/persistence/list_queries_test.go index e217c69..4a076fe 100644 --- a/pkg/persistence/list_queries_test.go +++ b/pkg/persistence/list_queries_test.go @@ -27,7 +27,8 @@ func TestSaveAndLoadOfflineList(t *testing.T) { require.NotEqual(l.ID, ListID(0)) // ID should be assigned when it's saved // Check it comes back the same - new_l := profile.GetListById(l.ID) + new_l, err := profile.GetListById(l.ID) + require.NoError(err) assert.Equal(l.ID, new_l.ID) assert.Equal(l.IsOnline, new_l.IsOnline) assert.Equal(l.Name, new_l.Name) @@ -50,7 +51,8 @@ func TestRenameOfflineList(t *testing.T) { profile.SaveList(&l) // Rename should be effective - new_l := profile.GetListById(l.ID) + new_l, err := profile.GetListById(l.ID) + require.NoError(err) assert.Equal(l.ID, new_l.ID) assert.Equal(l.IsOnline, new_l.IsOnline) assert.Equal(l.Name, new_l.Name) @@ -70,7 +72,8 @@ func TestSaveAndLoadOnlineList(t *testing.T) { require.NotEqual(l.ID, ListID(0)) // ID should be assigned when it's saved // Check it comes back the same - new_l := profile.GetListById(l.ID) + new_l, err := profile.GetListById(l.ID) + require.NoError(err) assert.Equal(l.ID, new_l.ID) assert.Equal(l.IsOnline, new_l.IsOnline) assert.Equal(l.OnlineID, new_l.OnlineID) // Check OnlineID for online lists @@ -94,7 +97,8 @@ func TestRenameOnlineList(t *testing.T) { profile.SaveList(&l) // Rename should be effective - new_l := profile.GetListById(l.ID) + new_l, err := profile.GetListById(l.ID) + require.NoError(err) assert.Equal(l.ID, new_l.ID) assert.Equal(l.IsOnline, new_l.IsOnline) assert.Equal(l.OnlineID, new_l.OnlineID) // Check OnlineID for online lists