sorted_forward_list.hh

00001 // This file is a part of Aurelia.
00002 // Copyright (C) 2010  Valentin David
00003 // Copyright (C) 2010  University of Bergen
00004 //
00005 // This program is free software: you can redistribute it and/or modify
00006 // it under the terms of the GNU General Public License as published by
00007 // the Free Software Foundation, either version 3 of the License, or
00008 // (at your option) any later version.
00009 //
00010 // This program is distributed in the hope that it will be useful,
00011 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00013 // GNU General Public License for more details.
00014 //
00015 // You should have received a copy of the GNU General Public License
00016 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
00017 
00018 #ifndef __SORTED_FORWARD_LIST_HH
00019 # define __SORTED_FORWARD_LIST_HH
00020 
# include <atomic>
# include <cassert>
# include <cstddef>
# include <functional>
# include <memory>
# include <utility>
# include "simple_stack.hh"
# include "marked_pointer.hh"
00026 
00027 namespace lockfree {
00028   template <typename Elt,
00029             typename Eq = std::equal_to<Elt>,
00030             typename Compare = std::less<Elt>,
00031             typename Alloc = std::allocator<Elt>,
00032             typename Multi = typename std::true_type
00033             >
00034   class sorted_forward_list {
00035   private:
00036     struct node;
00037     struct node_base {
00038     private:
00039       std::atomic<node_base*> next;
00040     public:
00041       node_base(node_base* next): next(next) {}
00042       node_base* get_next() const {
00043         return next.load();
00044       }
00045       bool set_next(node_base* expected, node_base* newnext) {
00046         return next.compare_exchange_strong(expected, newnext);
00047       }
00048       void force_set_next(node_base* newnext) {
00049         node_base *old;
00050         do
00051           old = next.load();
00052         while (!next.compare_exchange_strong(old, newnext));
00053       }
00054       const Elt& operator*() const {
00055         return static_cast<const node*>(this)->operator*();
00056       }
00057 
00058       node_base& operator=(node_base&& other) {
00059         std::swap(next, other.next);
00060         return *this;
00061       }
00062     };
00063     struct node: public node_base {
00064     private:
00065       Elt e;
00066     public:
00067       node(Elt&& e, node_base* next): node_base(next), e(std::move(e)) {}
00068       node& operator=(node&& other) {
00069         std::swap(e, other.e);
00070         this->node_base::operator=((node_base&&)std::move(other));
00071         return *this;
00072       }
00073       const Elt& operator*() const {
00074         return e;
00075       }
00076     };
00077 
00078     typedef typename Alloc::template rebind<node>::other node_allocator_t;
00079     typedef simple_stack<node*,Alloc> sstack_t;
00080 
00081     node_allocator_t    node_alloc;
00082     Eq                  eq;
00083     Compare             compare;
00084     node_base           head;
00085     sstack_t            free;
00086     std::atomic<size_t> _size;
00087 
00088   public:
00089 
00090     size_t size() const {
00091       return _size.load();
00092     }
00093 
00094     sorted_forward_list(Eq eq = Eq(), Compare compare = Compare(),
00095                         Alloc alloc = Alloc())
00096       : node_alloc(alloc), eq(eq), compare(compare), head(NULL), free(alloc),
00097         _size(0) {}
00098 
00099     struct iterator;
00100 
00101     struct const_iterator: private iterator {
00102       explicit const_iterator(node_base* n): iterator(n) {}
00103       const_iterator(const iterator& other): iterator(other) {}
00104       const_iterator(const const_iterator& other): iterator(other) {}
00105 
00106       node_base* get_node() const {
00107         return this->iterator::get_node();
00108       }
00109 
00110       const Elt& operator*() const {
00111         return this->iterator::operator*();
00112       }
00113       const_iterator& operator++() {
00114         this->iterator::operator++();
00115         return *this;
00116       }
00117       const_iterator& operator++(int) {
00118         const_iterator tmp(*this);
00119         this->iterator::operator++();
00120         return tmp;
00121       }
00122       bool operator==(const const_iterator& other) const {
00123         return this->iterator::operator==(other);
00124       }
00125       bool operator!=(const const_iterator& other) const {
00126         return this->iterator::operator!=(other);
00127       }
00128       bool operator==(const iterator& other) const {
00129         return this->iterator::operator==(other);
00130       }
00131       bool operator!=(const iterator& other) const {
00132         return this->iterator::operator!=(other);
00133       }
00134     };
00135 
00136     struct iterator {
00137     private:
00138       node_base* n;
00139 
00140     public:
00141       explicit iterator(node_base* n): n(n) {}
00142       iterator(const iterator& other): n(other.n) {}
00143 
00144       const Elt& operator*() const {
00145         return **n;
00146       }
00147       iterator& operator++() {
00148         n = n->get_next();
00149         return *this;
00150       }
00151 
00152       iterator& operator++(int) {
00153         iterator tmp(*this);
00154         n = n->get_next();
00155         return tmp;
00156       }
00157 
00158       bool operator==(const iterator& other) const {
00159         return n == other.n;
00160       }
00161 
00162       bool operator!=(const iterator& other) const {
00163         return n != other.n;
00164       }
00165 
00166       node_base* get_node() const {
00167         return n;
00168       }
00169     };
00170 
00171     const_iterator begin() const {
00172       return const_iterator(head.get_next());
00173     }
00174 
00175     const_iterator end() const {
00176       return const_iterator(NULL);
00177     }
00178 
00179     iterator begin() {
00180       return iterator(head.get_next());
00181     }
00182 
00183     iterator end() {
00184       return iterator(NULL);
00185     }
00186 
00187     std::pair<iterator,bool>
00188     insert_after(Elt&& elt, const const_iterator& hint,
00189                  bool once = !Multi::value) {
00190       return insert_after(std::move(elt), hint.get_node(), once);
00191     }
00192 
00193   private:
00194     std::pair<iterator,bool>
00195     insert_after(Elt&& elt, node_base* hint,
00196                  bool once = !Multi::value) {
00197       node* n = node_alloc.allocate(1);
00198       node_alloc.construct(n, std::move(elt), (node_base*)NULL);
00199 
00200       std::pair<node_base*,node_base*> found;
00201       while (true) {
00202         found = search(hint, **n);
00203         if (once)
00204           if ((found.second==NULL)?false:eq(**found.second, **n)) {
00205             std::swap(*const_cast<Elt*>(&**n), elt);
00206             node_alloc.destroy(n);
00207             node_alloc.deallocate(n, 1);
00208             return std::pair<iterator,bool>
00209               (iterator(static_cast<node*>(found.second)),
00210                false);
00211           }
00212 
00213         n->force_set_next(found.second);
00214         if (found.first->set_next(found.second, n)) {
00215           _size++;
00216           return std::pair<iterator,bool>(iterator(n), true);
00217         }
00218       };
00219     }
00220 
00221   public:
00222     std::pair<iterator,bool> insert(Elt&& elt) {
00223       return insert_after(std::move(elt), &head);
00224     }
00225 
00226   private:
00227     std::pair<node_base*, node_base*>
00228     search(node_base* hint, const Elt& elt, node_base* exact = NULL) {
00229       typedef std::pair<node_base*, node_base*> rt;
00230 
00231       node_base* before = NULL;
00232       node_base* before_next = NULL;
00233 
00234       while (true) {
00235         node_base* t = hint;
00236         node_base* t_next = t->get_next();
00237 
00238         do {
00239           if (!details::is_marked(t_next)) {
00240             before = t;
00241             before_next = t_next;
00242           }
00243 
00244           if (NULL == (t = details::unmarked(t_next)))
00245             break ;
00246           t_next = t->get_next();
00247         } while (details::is_marked(t_next)?true:
00248                  (exact==NULL)?(compare(**t, elt)?true:
00249                                 (!compare(elt, **t)?!eq(**t,elt):false)):
00250                  (exact==t)?false:!compare(elt, **t));
00251 
00252         if (before_next == t) {
00253           if ((t == NULL)?false:details::is_marked(t->get_next()))
00254             continue ;
00255           else
00256             return rt(before, t);
00257         }
00258 
00259         if (before->set_next(before_next, t)) {
00260           if ((t == NULL)?false:!details::is_marked(t->get_next()))
00261             continue ;
00262           else
00263             return rt(before, t);
00264         }
00265       }
00266     }
00267 
00268     bool erase(node_base* item, node_base* hint) {
00269       std::pair<node_base*, node_base*> found;
00270       node_base* next;
00271 
00272       while (true) {
00273         found = search(hint, **item, item);
00274         if ((found.second == NULL)?true:(found.second != item)) {
00275           return false;
00276         }
00277         next = found.second->get_next();
00278         if (!details::is_marked(next))
00279           if (found.second->set_next(next, details::marked(next))) {
00280             break ;
00281           }
00282       }
00283       _size--;
00284       node_alloc.destroy(static_cast<node*>(found.second));
00285       free.push(static_cast<node*>(found.second));
00286       if (!found.first->set_next(found.second, next))
00287         search(hint, **found.second);
00288       return true;
00289     }
00290 
00291     bool erase(const Elt& item, node_base* hint) {
00292       static_assert(!Multi::value, "Cannot erase elements on multilists");
00293       std::pair<node_base*, node_base*> found;
00294       node_base* next;
00295 
00296       while (true) {
00297         found = search(hint, item);
00298         if ((found.second == NULL)?true:(!eq(**found.second,item))) {
00299           return ;
00300         }
00301         next = found.second->get_next();
00302         if (!details::is_marked(next))
00303           if (found.second->set_next(next, details::marked(next))) {
00304             break ;
00305           }
00306       }
00307       if (found.first->set_next(found.second, next)) {
00308         --_size;
00309         node_alloc.destroy(static_cast<node*>(found.second));
00310         free.push(static_cast<node*>(found.second));
00311         return true;
00312       }
00313       else {
00314         search(hint, **found.second);
00315       }
00316     }
00317 
00318 
00319   public:
00320     bool erase(const Elt& elt) {
00321       static_assert(!Multi::value, "Cannot erase elements on multilists");
00322       return erase(elt, &head);
00323     }
00324 
00325     bool erase(const Elt& elt, const const_iterator& hint) {
00326       static_assert(!Multi::value, "Cannot erase elements on multilists");
00327       return erase(elt, hint.get_node());
00328     }
00329 
00330     bool erase(const const_iterator& it, const const_iterator& hint) {
00331       return erase(it.get_node(), hint.get_node());
00332     }
00333 
00334     bool erase(const iterator& it) {
00335       return erase(it.get_node(), &head);
00336     }
00337 
00338   private:
00339     const_iterator find(const Elt& elt, node_base*hint) const {
00340       return const_cast<sorted_forward_list*>(this)->find(elt, hint);
00341     }
00342 
00343     iterator find(const Elt& elt, node_base*hint) {
00344       std::pair<node_base*, node_base*> found = search(hint, elt);
00345       if (found.second == NULL)
00346         return iterator(NULL);
00347       if (eq(**found.second, elt))
00348         return iterator(static_cast<node*>(found.second));
00349       return iterator(NULL);
00350     }
00351 
00352   public:
00353     const_iterator find(const Elt& elt) const {
00354       return find(elt, &head);
00355     }
00356     const_iterator find(const Elt& elt, const const_iterator& hint) const {
00357       return find(elt, hint.get_node());
00358     }
00359 
00360     iterator find(const Elt& elt) {
00361       return find(elt, &head);
00362     }
00363     iterator find(const Elt& elt, const const_iterator& hint) {
00364       return find(elt, hint.get_node());
00365     }
00366 
00367     bool collect() {
00368       bool did = false;
00369       while (true) {
00370         typename sstack_t::wrapped w = free.pop();
00371         if (w.end())
00372           break ;
00373         did = true;
00374         node* n = *w;
00375         //node_alloc.destroy(n);
00376         node_alloc.deallocate(n, 1);
00377       }
00378       return did;
00379     }
00380 
00381     void clear() {
00382       node* last = NULL;
00383       for (node_base* n = head.get_next();
00384            ; n = n->get_next()) {
00385         if (last != NULL) {
00386           node_alloc.destroy(last);
00387           node_alloc.deallocate(last, 1);
00388         }
00389         if (n == NULL) {
00390           head.force_set_next(NULL);
00391           break ;
00392         }
00393         last = static_cast<node*>(n);
00394       }
00395       collect();
00396     }
00397 
00398     ~sorted_forward_list() {
00399       clear();
00400     }
00401   };
00402 
00403 }
00404 
00405 #endif